🐍

誤差逆伝播法の数式を幾何学で解釈する

2024/04/11に公開

はじめに

機械学習の勉強をしていて次のような公式を知った。

\frac{\partial L}{\partial \mathbf{X}} = \mathbf{W}^\top \cdot \frac{\partial L}{\partial \mathbf{Y}} \quad (L = L(\mathbf{Y}), \mathbf{Y} = \mathbf{W} \cdot \mathbf{X} + \mathbf{B})

証明は簡単だが、\mathbf{W} が転置になることの説明を考えていたところ幾何学の言葉(微分形式)を使うと見通しがいいことに気づいたので、自分用にまとめた。

ベクトル・行列の微分公式

L = L(\mathbf{Y})

\mathbf{Y} = \left( \begin{matrix} y_1 \\ \vdots \\ y_n \\ \end{matrix}\right)

の関数、\mathbf{X}

\mathbf{X} = \left( \begin{matrix} x_1 \\ \vdots \\ x_m \\ \end{matrix}\right)

の関数として行列

\mathbf{W} = \left( \begin{matrix} w_{11} & \dots & w_{1m} \\ \vdots & & \vdots \\ w_{n1} & \dots & w_{nm} \\ \end{matrix}\right)

とベクトル

\mathbf{B} = \left( \begin{matrix} b_1 \\ \vdots \\ b_m \\ \end{matrix}\right)
\mathbf{Y} = \mathbf{W} \cdot \mathbf{X} + \mathbf{B}

と表されているとする。このとき、L\mathbf{Y}, \mathbf{X}, \mathbf{W} で成分ごとに微分して得られるベクトル・行列

\begin{align} \frac{\partial L}{\partial \mathbf{Y}} &=\left( \begin{matrix} \frac{\partial L}{\partial y_1} \\ \vdots \\ \frac{\partial L}{\partial y_n} \\ \end{matrix}\right) \\ \frac{\partial L}{\partial \mathbf{Y}} &= \left( \begin{matrix} \frac{\partial L}{\partial x_1} \\ \vdots \\ \frac{\partial L}{\partial x_m} \\ \end{matrix}\right) \\ \frac{\partial L}{\partial \mathbf{W}} &= \left( \begin{matrix} \frac{\partial L}{\partial w_{11}} & \dots &\frac{\partial L}{\partial w_{1m}} \\ \vdots & & \vdots \\ \frac{\partial L}{\partial w_{n1}} & \dots &\frac{\partial L}{\partial w_{nm}} \\ \end{matrix}\right) \end{align}

の間には以下の関係がある。

\begin{align} \frac{\partial L}{\partial \mathbf{X}} &= \mathbf{W}^\top \cdot \frac{\partial L}{\partial \mathbf{Y}} \\ \frac{\partial L}{\partial \mathbf{W}} &= \frac{\partial L}{\partial \mathbf{Y}} \cdot \mathbf{X}^\top \end{align}

これらの公式はニューラルネットワークの学習の基本となる誤差逆伝播法の中に頻出する。

微分形式を用いた解釈

M = \mathbb{R}^m, N = \mathbb{R}^n とおき、関数 f \colon M \to Nf(\mathbf{X}) = \mathbf{W} \cdot \mathbf{X} + \mathbf{b} とおく。
このとき \frac{\partial L}{\partial \mathbf{Y}}, \frac{\partial L}{\partial \mathbf{X}} はそれぞれ微分形式

\begin{align} dL &= \sum_{j = 1}^n \frac{\partial L}{\partial y_j} dy_j\\ f^*(dL) &= d(L \circ f) = \sum_{i = 1}^m \frac{\partial L}{\partial x_i} dx_i\\ \end{align}

と同一視できる。より正確にいうと、座標の取り方によらない概念である dL, f^*(dL) を、座標に付随した余接束の局所枠(local frame, ベクトル束のファイバーごとの基底を与える局所切断の組のこと)である (dy_j)_j, (dy_i)_i を使って表示したものが \frac{\partial L}{\partial \mathbf{Y}}, \frac{\partial L}{\partial \mathbf{X}} だと解釈できる。

よって \frac{\partial L}{\partial \mathbf{Y}} を使って \frac{\partial L}{\partial \mathbf{X}} を表示することは、f による微分形式の引き戻し射

f^* \colon \Gamma(N, T^*N) \to \Gamma(M, T^*M)

を枠 (dy_j)_j, (dy_i)_i を使って表示することに他ならない。各点ごとに見ればいいから p \in M を固定すると、線形写像

f^* \colon T_{f(p)}^*N \to T_p^*M

を基底 (dy_j)_j, (dy_i)_i で行列表示するという話になる。この写像の双対が接ベクトルの押し出し射(= f の微分)

f_* \colon T_pM \to T_{f(p)}N

であること、それを基底 (\partial/\partial x_i)_i, (\partial/\partial y_j)_j で表示したときの行列が \mathbf{W} になること(今回 f がアファイン関数だから、その微分は1次の係数そのものである!)、一般に双対射を双対基底で表示すると元の射の表現行列の転置が現れること、を合わせることで求める行列は \mathbf{W}^\top とわかる。よって

\frac{\partial L}{\partial \mathbf{X}} = \mathbf{W}^\top \cdot \frac{\partial L}{\partial \mathbf{Y}}

が成り立つ。2つ目の

\frac{\partial L}{\partial \mathbf{W}} = \frac{\partial L}{\partial \mathbf{Y}} \cdot \mathbf{X}^\top

は1つ目から(ちょっと工夫すると)出る。

まとめ

  • \frac{\partial L}{\partial \mathbf{Y}}, \frac{\partial L}{\partial \mathbf{X}} は座標を固定したときの微分形式の表示である。
  • それらの間の関係は微分形式の引き戻しだから、接ベクトルの押し出しの行列表示である \mathbf{W} の双対として転置 \mathbf{W}^\top が現れる。

Discussion