🧠

微分幾何学を用いた次世代ニューラルネットワーク

に公開

はじめに

これは、広島大学 HiCoder & ゲーム制作同好会GSD Advent Calendar 2025 の12/10の記事 です。

以下の「ニューラル微分多様体:明示的な幾何学的構造を持つアーキテクチャ」論文が非常に面白かったので紹介します。
https://arxiv.org/abs/2510.25113

想定読者は、機械学習の基礎知識を理解している方です。

現在の機械学習と対称性の潮流

近年の機械学習の分野では、ニューラルネットワーク(NN)が非常に高い性能を発揮しています。
これらのモデルは、大量のデータからパターンを学習し、画像認識、自然言語処理など多様なタスクで成功を収めています。
しかし、これらのモデルは大規模化することで性能を向上させることが一般的であり、その結果、 計算資源の増大、金銭や時間コスト増加、環境負荷 が問題になっています。
そこで、対称性を活用することで、モデルの効率性と性能を向上させる研究が進められています。
例えば、畳み込みニューラルネットワーク(CNN)[1]並進対称性 を利用します。画像が水平方向や垂直方向に移動しても、同じ特徴が適切に検出されるため、並進した画像をデータ拡張で大量に追加する必要がありません。
CNNの並進対称性

ただし、従来のNN(全結合層など)はユークリッド空間(平坦な空間)を仮定しているため、データが持つ複雑な幾何学的構造(曲率など)を捉えるのに非効率です。

本論文が提案する Neural Differential Manifold(NDM) は、この問題に対してアプローチしています。
NDM は、ネットワーク内部に「対称性を持つ幾何学構造」を直接組み込み、より少ないパラメータとデータで高い性能を達成できる可能性 を示します。

微分幾何学の基礎

微分幾何学は、曲線や曲面などの滑らかな多様体[2]の性質を研究する数学の一分野です。
応用例として、一般相対論における時空の膨張や曲率の記述などがあります。

この分野において、多様体上の幾何学的構造を特徴づける核となる概念は以下の4つです。

  • 計量(Metric)
  • 共変微分 (Covariant Derivative)
  • 曲率(Curvature)
  • 体積(Volume)

計量(Metric)

計量 g は、多様体上での距離と角度を定義します。

ds^2 = g_{ij} dx^i dx^j

そして、ここに多様体の幾何的構造の情報がすべて組み込まれています。
NDM では、この計量をネットワークが学習し、意味的距離を内部表現空間に形成できる点が重要です。

共変微分 (Covariant Derivative)

多様体 M にリーマン計量 g が導入されると、共変微分 \nabla が、この幾何学の最も基本的な微分となります。
これは、多様体という曲がった空間でベクトルやテンソルを微分し、変化率を正確に測定するために不可欠です。

共変微分は、多様体の曲がりによって座標基底が変化する影響を取り除き、ベクトル場の真の変化率を測定する操作 です。
以下の図は、2つの異なる座標系で定義された単位ベクトルが異なる向きを持つことを示してしており、これを微分する際に考慮する必要があります。
2つの座標での単位ベクトルの違い

ベクトル場 A^\mu の共変微分は次式で与えられます。

\nabla_\sigma A^\mu = \frac{\partial A^\mu}{\partial x^\sigma} + \Gamma^\mu_{\sigma\alpha} A^\alpha
  • \partial A^\mu / \partial x^\sigma … 成分の通常の微分
  • \Gamma^\mu_{\sigma\alpha} … 座標系の「曲がり」を補正するクリストッフェル記号(接続係数)。

曲率(Curvature)

曲率は、多様体がどれだけ曲がっているかを表します。
基本的な量としてリーマン曲率テンソルがあります。

R_{ijk}^{\ \ l} = \frac{\partial \Gamma^l_{jk}}{\partial x^i} - \frac{\partial \Gamma^l_{ik}}{\partial x^j} + \Gamma^m_{jk} \Gamma^l_{mi} - \Gamma^m_{ik} \Gamma^l_{mj}

曲率テンソルを縮約すると、

  • リッチ曲率 R_{ij}
  • スカラー曲率 R

が得られます。

体積(Volume)

計量の行列式から体積要素 \sqrt{\det(g)} が定義できます。

dV = \sqrt{\det(g)} \ dx^1 \wedge dx^2 \wedge \dots \wedge dx^n

Neural Differential Manifold (NDM) アーキテクチャ

NDM は次の3層構造で記述されます。

  1. Coordinate Layer(座標層): 局所座標系間の遷移(層間の変換)
  2. Geometric Layer(幾何学層): 計量の動的な生成
  3. Evolution Layer(進化層): 幾何学的正則化を含む最適化

それぞれの役割と数理的な仕組みを見ていきましょう。

1. Coordinate Layer(座標層):局所座標系間の遷移

NDMにおいて、各層 L_i の活性化状態は、多様体 M 上の局所座標系(U_i, x_i) とみなされます。
従来のNNでは層間の接続は単なる行列演算と非線形変換ですが、NDMではこれを「滑らかな座標変換」として再定義します。

隣接する層 L_i から L_j への変換 \phi_{i \to j} は、微分同相写像[3]が理想的です。

x_j = \phi_{i \to j}(x_i)

実用的には、この \phiNormalizing Flows[4]の技術を用いて実装されます。これにより、情報の損失を防ぎつつ(可逆性)、滑らかに内部表現を変換することが可能になります。
ネットワーク全体の順伝播は、これら座標変換の合成として表現され、データポイントが多様体上を移動する様子を捉えます。

x_{output} = \phi_{L \to L-1} \circ \dots \circ \phi_{1 \to 2}(x_{input})

2. Geometric Layer(幾何学層):計量の動的生成

ここがNDMの最もコアな部分です。多様体の形(距離や角度)を決める計量テンソル g は、あらかじめ固定されたものではなく、ネットワーク自身が学習によって動的に生成します。

各層 L_i には、補助的なサブネットワークである Metric Net (M_i) が付随します。
Metric Net は、その層の局所座標(活性化値)x_i を入力とし、その点における計量テンソル g_{ij}(x_i) を出力します。

計量テンソルは正定値対称行列である必要があるため、Metric Netは通常、下三角行列 L を出力し、コレスキー分解の逆操作のような形で計量を構成します(\epsilonは数値安定性のための微小項)。

g(x_i) = L(x_i)L(x_i)^T + \epsilon I

これにより、ネットワークの内部表現空間における距離が定義されます。同じユークリッド距離を持つ2つの点でも、Metric Netが学習した幾何学上では意味的に遠い(または近い)と解釈されるようになります。

高次元を扱うとイメージしづらいので、2次元多様体上の座標変換の例を示します。

NDMにおける座標変換のイメージ

3. Evolution Layer(進化層):二重目的の最適化

進化層は物理的な層ではなく、学習プロセスを統括する概念的な層です。[5]ここでは、タスクの性能向上だけでなく、「幾何学的な単純さ」 を保つように多様体を進化させます。

損失関数 L_{\text{total}} は、タスク損失 L_{\text{task}} と幾何学的正則化項 L_{\text{geo}} の和で表されます。

L_{total} = L_{task}(\theta) + \lambda L_{geo}(g(\theta))

幾何学的正則化 (L_{\text{geo}})

論文では、過学習を防ぎ、汎化性能を高めるために、以下の2つの幾何学的ペナルティを提案しています。

  1. 曲率正則化 (L_{\text{curv}}):
    多様体が複雑に歪むことが過学習を生むので、リッチ曲率スカラー R の大きさをペナルティとします。これにより、表現空間は可能な限り平坦になろうとします。スカラー曲率 R を求めるには、リーマン曲率テンソルの縮約が必要であり、テンソル計算の計算量が次元数に対して非常に重くなります。 (計算量はO(n^4), O(n^3)

    L_{curv} = \mathbb{E}_{p \in M} [R(p)^2]

  2. 体積正則化 (L_{\text{vol}}):
    層を通過するごとの体積要素 \sqrt{\det(g)} の急激な変化や偏りを抑制します。これにより数値的な安定性と表現の効率性を保ちます。

    L_{vol} = \text{Var}(\sqrt{\det(g(p))})

自然勾配降下法による学習

NDMの真価は、最適化手法にあります。
計量 g が定義されているため、通常の勾配降下法ではなく、自然勾配降下法 (Natural Gradient Descent) を適用することとなります。

通常のパラメータ更新 \theta \leftarrow \theta - \eta \nabla L は、パラメータ空間がユークリッド的(平坦)であると仮定しています。
しかし、NDMではパラメータ空間の幾何学構造 G(\theta) (フィッシャー情報行列)[6]を考慮し、以下のように更新します。

\tilde{\nabla} L = G(\theta)^{-1} \nabla L
\theta \leftarrow \theta - \eta \tilde{\nabla} L

通常、自然勾配法における G(\theta) は、パラメータ空間のを指します。
一方で、NDMにおける g(Metric Netが出力するもの)は、データ表現空間(多様体) の計量です。

これにより、多様体上の実質的な変化量 に基づいて、最短経路で最適解へ向かうことが可能になります。

実験結果

ソースコードがなかったので、AIを使い論文から推測して実装しました[7]
曲率正則化 L_{\text{curv}} の実装はできる気がしなかったので、実装していません。
その結果、多様体が異常に曲がって過学習している傾向が見えています。(Acc: 1.00)
さらに、今回2次元データを用いたため、高次元に拡張するには今の実装のままではできません。

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split

# ==========================================
# 1. Model Architecture (NDM)
# ==========================================

class MetricNet(nn.Module):
    """
    入力座標 x に基づいて、その点における計量テンソル g を生成する。
    正定値性を保証するため L @ L.T 形式を使用。
    """
    def __init__(self, input_dim, hidden_dim=32):
        super().__init__()
        self.input_dim = input_dim
        # 下三角行列の要素数
        self.output_dim = input_dim * (input_dim + 1) // 2

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, self.output_dim)
        )

    def forward(self, x):
        batch_size = x.size(0)
        l_params = self.net(x)

        # 下三角行列 L を構築
        L = torch.zeros(batch_size, self.input_dim, self.input_dim, device=x.device)
        indices = torch.tril_indices(self.input_dim, self.input_dim)
        L[:, indices[0], indices[1]] = l_params

        # 対角成分を正にする (Softplus)
        diagonal_indices = torch.arange(self.input_dim)
        L[:, diagonal_indices, diagonal_indices] = F.softplus(L[:, diagonal_indices, diagonal_indices])

        # 計量 g = L L^T + epsilon
        g = torch.bmm(L, L.transpose(1, 2)) + torch.eye(self.input_dim, device=x.device) * 1e-6
        return g, L

class ManifoldLayer(nn.Module):
    """
    座標変換(Coordinate) + 計量生成(Geometric)
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        self.coordinate_map = nn.Linear(in_features, out_features)
        self.metric_net = MetricNet(out_features)

    def forward(self, x):
        # 1. 座標変換 z = f(x)
        z = F.elu(self.coordinate_map(x))

        # 2. 幾何構造 g(z)
        g, L = self.metric_net(z)

        # 体積要素の計算 (正則化用)
        # log_det = 2 * sum(log(diag(L)))
        diag_L = torch.diagonal(L, dim1=-2, dim2=-1)
        log_det_g = 2 * torch.sum(torch.log(diag_L + 1e-8), dim=1)
        volume_element = torch.exp(0.5 * log_det_g)

        return z, g, volume_element

class NeuralDifferentialManifold(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        self.layers = nn.ModuleList()

        dims = [input_dim] + hidden_dims

        for i in range(len(dims) - 1):
            self.layers.append(ManifoldLayer(dims[i], dims[i+1]))

        self.final_layer = nn.Linear(hidden_dims[-1], output_dim)

    def forward(self, x):
        geometry_stats = []

        # 各層を通過
        for layer in self.layers:
            x, g, vol = layer(x)
            geometry_stats.append({
                'metric': g,
                'volume': vol
            })

        output = self.final_layer(x)
        return output, geometry_stats

# ==========================================
# 2. Utility for True Geometry Visualization
# ==========================================

def compute_pullback_volume(model, grid_tensor):
    """
    モデル自体が定義する「引き戻し計量 (Pullback Metric)」の体積要素を計算
    """
    model.eval()
    volumes = []

    # バッチ処理ではなく1点ずつ計算 (autograd.gradのため)
    # 高速化のためには vmap などを使うが、ここでは分かりやすさ優先
    for point in grid_tensor:
        point = point.unsqueeze(0).clone().detach().requires_grad_(True)

        # モデルの出力 (Logits -> Prob)
        logits, _ = model(point)
        output = torch.sigmoid(logits)

        # ヤコビ行列 J = d(output)/d(input)
        # 入力(2次元) -> 出力(1次元) なので J は勾配ベクトルそのもの
        grad_params = torch.autograd.grad(output, point, create_graph=False)[0]

        # J のノルム(勾配の大きさ)が「局所的な空間の伸縮率」に相当
        # Strictly speaking: sqrt(det(J^T J)) but for R^2->R^1 map, it's just norm.
        vol = torch.norm(grad_params)

        volumes.append(vol.item())

    return np.array(volumes)

# ==========================================
# 3. Main Training & Visualization Loop
# ==========================================

def run_experiment():
    # --- Data Preparation ---
    print("Generating Two Moons data...")
    X, y = make_moons(n_samples=500, noise=0.1, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train).unsqueeze(1)

    # --- Model Setup ---
    # 入力2次元 -> 隠れ層16 -> 出力1次元
    model = NeuralDifferentialManifold(input_dim=2, hidden_dims=[16], output_dim=1)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Hyperparameters
    LAMBDA_GEO = 0.05  # 幾何正則化の強さ
    EPOCHS = 600

    # --- Training Loop ---
    print(f"Starting training for {EPOCHS} epochs...")
    for epoch in range(EPOCHS):
        model.train()
        optimizer.zero_grad()

        # Forward Pass
        output, geo_stats = model(X_train_t)

        # 1. Task Loss (BCE)
        task_loss = F.binary_cross_entropy_with_logits(output, y_train_t)

        # 2. Geometric Regularization (Volume Variance)
        # 空間の歪みが極端にならないように分散を抑える
        vol_loss = 0
        for stats in geo_stats:
            vol_loss += torch.var(stats['volume'])

        # Total Loss
        total_loss = task_loss + LAMBDA_GEO * vol_loss

        total_loss.backward()
        optimizer.step()

        if epoch % 100 == 0:
            # Accuracy
            preds = (torch.sigmoid(output) > 0.5).float()
            acc = (preds == y_train_t).float().mean()
            print(f"Epoch {epoch:03d} | Loss: {total_loss.item():.4f} (Task: {task_loss.item():.4f}, Geo: {vol_loss.item():.4f}) | Acc: {acc:.2f}")

    print("Training finished.")

    # --- Visualization ---
    print("Computing visualization maps...")

    # Create Grid
    x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
    y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.05),
                         np.arange(y_min, y_max, 0.05))

    grid_tensor = torch.FloatTensor(np.c_[xx.ravel(), yy.ravel()])

    # 1. Decision Boundary (Predictions)
    with torch.no_grad():
        model.eval()
        logits, _ = model(grid_tensor)
        probs = torch.sigmoid(logits).reshape(xx.shape)

    # 2. True Geometric Structure (Sensitivity / Pullback Volume)
    volumes = compute_pullback_volume(model, grid_tensor)
    volumes = volumes.reshape(xx.shape)

    # --- Plotting ---
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # Plot 1: Decision Boundary
    ax1.set_title("Decision Boundary (Probability Field)")
    contour1 = ax1.contourf(xx, yy, probs, levels=20, cmap="RdBu", alpha=0.8)
    ax1.scatter(X_train[:, 0], X_train[:, 1], c=y_train, edgecolors='k', cmap="RdBu_r", s=30)
    fig.colorbar(contour1, ax=ax1)
    ax1.set_xlabel("Feature 1")
    ax1.set_ylabel("Feature 2")

    # Plot 2: Induced Geometry (Sensitivity)
    # 明るい部分 = 空間が引き伸ばされている = 情報密度が高い = 決定境界
    ax2.set_title("Induced Geometry: $\det(g)$")
    contour2 = ax2.contourf(xx, yy, volumes, levels=20, cmap="magma", alpha=0.9)
    ax2.scatter(X_train[:, 0], X_train[:, 1], c=y_train, edgecolors='k', cmap="RdBu_r", s=30, alpha=0.3)
    fig.colorbar(contour2, ax=ax2)
    ax2.set_xlabel("Feature 1")
    ax2.set_ylabel("Feature 2")

    plt.tight_layout()
    plt.show()

run_experiment()

実行結果は以下のようになります。
確率の境界と体積要素
赤と青のデータの境界付近は、体積要素 \sqrt{\det(g)} が大きくなっていることが分かります。
この境界付近では、モデルが情報を多く必要とするため、空間が引き伸ばされている(感度が高い)ことを示しています。

NDMの利点と可能性

NDMがもたらす利点は、

  1. 解釈可能性(Explainability):
    曲率が高い領域が決定境界に対応するなど、内部表現の幾何学的な分析が可能
  2. 継続学習(Continual Learning):
    新しいタスクを学習する際、幾何構造の歪みを検知して局所的に適応することで、過去の知識の忘却(破滅的忘却)を防げる可能性
  3. 科学的発見への応用:
    対称性や保存則を持つデータを扱う際、幾何構造として自然に学習できる。物理や化学のデータは幾何的制約や対称性を従うことが多いため、NDMはこれらのドメインで特に有効であると期待

NDMの欠点と課題

  1. 計算コストとメモリ:
    曲率計算(二階微分)や逆行列演算の負荷が高く、メモリ消費も層の幅の二乗で増えるため、大規模化が難しい
  2. 数値的な不安定性:
    計量テンソルの正定値性の維持や逆行列の計算は数値的に不安定になりやすく、学習を安定させるための高度な実装技術(正則化や分解手法)が必要

まとめ

ニューラル微分多様体(NDM)は、ディープラーニングと微分幾何学の融合する試みです。
計算コストの増大(計量テンソルの計算と逆行列演算)や数値的安定性といった課題は残されていますが、ブラックボックスになりがちなニューラルネットワークに明示的な構造を与えるこのアプローチは、次世代のAIアーキテクチャの重要な指針となるかもしれません。

脚注
  1. ここではプーリング操作を除外した純粋な畳み込みを念頭に置く ↩︎

  2. 滑らかな多様体とは、微分可能な空間のこと ↩︎

  3. 2 つの多様体 M と N が与えられたとき、可微分写像 f: M → N は全単射かつ逆写像 f−1: N → M も可微分なとき微分同相写像 ↩︎

  4. 複雑な関数を単純な関数の合成関数で表現する技術 ↩︎

  5. 層だと積み重なっているようなニュアンスになるが、あまりそうではない ↩︎

  6. 確率分布の計量 ↩︎

  7. NN専用のライブラリ使いすぎで、正しいか心配です。詳しい人、ライブラリを作ってください ↩︎

Discussion