📝

行列積の誤差逆伝搬について

に公開

はじめに

最近ニューラルネットワーク関連の資料や論文を読んでいて、さらっと書かれていた行列積の誤差逆伝搬の公式でつまづいてしまうことがありました。
行列の記法による簡潔な表記は議論の上では非常に便利なのですが、一方で慣れていないと非直観的に感じられることもあるように思われます。行列の成分単位での計算に立ち戻って考えたほうが理解はしやすいです。
一方で、成分計算はそれなりに煩雑で、暗算で考えるのはちょっと面倒だったりします。

今回は頻出の行列積に対する誤差逆伝搬の計算について、自分なりに納得できるように導出を整理してみました。考え方をすぐに思い出せるようにメモしておきたいと思います。

行列積の誤差逆伝搬に関する具体的な公式を整理することが主な目的なので、誤差逆伝搬自体の考え方についてはごく簡易な導入で済ませることにします。

前提とする知識

  • 大学初年次程度の微(積)分
  • 誤差逆伝搬の基本的な枠組みについて

記法

  • 損失関数 (loss) を一貫して \mathcal{L} で表します。
  • \mathbf{X} = (X_1, \ldots, X_n) に関する f の勾配を \frac{\partial f}{\partial \mathbf{X}} = (\frac{\partial f}{\partial X_1}, \ldots, \frac{\partial f}{\partial X_n}) で表します。
  • 損失関数 \mathcal{L}\mathbf{X} に関する勾配を \mathrm{d}\mathbf{X} \equiv \frac{\partial \mathcal{L}}{\partial \mathbf{X}} と表すことにします[1]
  • Kronecker delta \delta_{ij} を次のように定義します:
\delta_{ij} = \begin{cases} 1 & \text{if } i = j \\ 0 & \text{if } i \neq j \end{cases}

この記事で示すこと

\mathbf{X} \in \mathbb{R}^{m \times k}, \mathbf{Y} \in \mathbb{R}^{k \times n}, \mathbf{Z} \in \mathbb{R}^{m \times n} として、順伝搬 (forward pass) で以下の計算が行われるとします。

\mathbf{Z = XY}

このとき、 \mathbf{Z} に関する勾配 \mathrm{d}\mathbf{Z} \in \mathbb{R}^{m \times n} を用いて誤差逆伝搬の計算は以下で行うことができます。

\mathrm{d}\mathbf{X} = \mathrm{d}\mathbf{Z}\, \mathbf{Y}^T \quad \text{where } \mathrm{d}\mathbf{X} \in \mathbb{R}^{m \times k}
\mathrm{d}\mathbf{Y} = \mathbf{X}^T \mathrm{d}\mathbf{Z} \quad \text{where } \mathrm{d}\mathbf{Y} \in \mathbb{R}^{k \times n}

誤差逆伝搬 (Backpropagation) について

ニューラルネットワークの学習の基本となる考え方は、「(入力) パラメータを少し変化させて、損失関数 (Loss) を減少させる」というステップを繰り返すことです。このためには、各入力パラメータ \mathbf{X} (e.g. 重み, バイアス) に関して、 loss を減少させるようなパラメータの更新方法を考えれば良いです。これはパラメータ \mathbf{X} に関する loss の勾配 (gradient) \frac{\partial \mathcal{L}}{\partial \mathbf{X}} を求めて、勾配のマイナス方向にパラメータ \mathbf{X} を更新することで達成できます。

このように、ニューラルネットワークの学習においては各パラメータに関する loss の勾配を求めるという計算が重要になります。勾配は微分に関する chain rule (後述) を用いて出力側から入力側に向かって勾配の値を伝搬させていくような形で計算することができます。この勾配計算は推論時と逆向きの流れで計算が進むので、一般に誤差逆伝搬 (backpropagation) と呼ばれます。

連鎖律 (Chain Rule)

いま、 x, z \in \mathbb{R} について

z = f(x)

という関係が成り立っているとします。
連鎖律 (chain rule) は出力側の変数 z に関する勾配 \frac{\partial \mathcal{L}}{\partial z} から入力側の変数 x に関する勾配 \frac{\partial \mathcal{L}}{\partial x} を計算する式です。

\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial z} \frac{\partial z}{\partial x}

誤差逆伝搬は出力側の勾配から入力側の勾配を求める計算であるため、 chain rule が非常に重要な役割を果たします。

ベクトルや行列の場合も基本は同じで、各成分に対する chain rule を考えれば良いです。すなわち、 \mathbf{Z} = f(\mathbf{X}) (\mathbf{X} \in \mathbb{R}^m, \mathbf{Z} \in \mathbb{R}^n) としたとき、以下のような chain rule を考えることになります。

\frac{\partial \mathcal{L}}{\partial X_j} = \sum_{i} \frac{\partial \mathcal{L}}{\partial Z_i} \frac{\partial Z_i}{\partial X_j} \\

あるいは loss に対する勾配の簡略表記を用いて、以下のように表現します。

\mathrm{d} X_j = \sum_{i} \mathrm{d} Z_i \frac{\partial Z_i}{\partial X_j}

図式的な説明

\mathbf{X} に関する勾配について考えます。
いきなり全ての成分を考えるのは大変なので、特定の成分 X_{st} に注目して考えます。

重要になるのは以下の chain rule です。

\frac{\partial \mathcal{L}}{\partial X_{st}} = \sum_{i, j} \frac{\partial \mathcal{L}}{\partial Z_{ij}} \frac{\partial Z_{ij}}{\partial X_{st}}

これは直観的には、注目する成分 X_{st} によって影響を受ける \mathbf{Z} の各成分 Z_{ij} に関する勾配 \frac{\partial \mathcal{L}}{\partial Z_{ij}} に影響の度合い \frac{\partial Z_{ij}}{\partial X_{st}} を掛け、足し合わせたものと理解できます。

\mathbf{Z} = \mathbf{X}\mathbf{Y} の計算において、 X_{st} が関係する部分を以下の図に示します。

\mathbf{Z} のうち、 X_{st} によって影響を受けるのは s 行目の各成分 Z_{sj} です。
Z_{sj}X_{s\cdot}Y_{\cdot j} (灰色の部分) の dot 積で計算され、うち X_{st} が関わるのは X_{st}Y_{tj} の項だけです。したがって、以下が成り立ちます。

\frac{\partial Z_{sj}}{\partial X_{st}} = Y_{tj}

j を動かして Z_{s1}, \ldots, Z_{sn} について同様に考えると、 \mathrm{d}X_{st} について以下が成り立つことがわかります。

\mathrm{d}X_{st} = \sum_j \mathrm{d}Z_{sj} \frac{\partial Z_{sj}}{\partial X_{st}} = \sum_j \mathrm{d}Z_{sj} Y_{tj}

これは、図中ではオレンジ色の 2 つの行の dot 積に相当する部分として考えることができます (厳密には \mathbf{Z}\mathrm{d}\mathbf{Z} を読みかえます)。
行どうしの dot 積だとうまく行列積として表現することはできませんが、ここで Y を転置してあげると、

\mathrm{d}X_{st} = \sum_j \mathrm{d}Z_{sj} (\mathbf{Y}^T)_{jt}

となって、以下のように行列積として表現できることがわかります。

\mathrm{d}\mathbf{X} = \mathrm{d}\mathbf{Z}\, \mathbf{Y}^T

同様に \mathbf{Y} の特定の成分 Y_{st} に関する勾配を考えます。

同様に考えると、

\frac{\partial Z_{it}}{\partial Y_{st}} = X_{is}

が成り立ち、 i 方向に全ての成分を考えることで以下の式が成り立ちます。

\mathrm{d}Y_{st} = \sum_i \mathrm{d}Z_{it} \frac{\partial Z_{it}}{\partial Y_{st}} = \sum_i \mathrm{d}Z_{it} X_{is}

右辺の dot 積は図中ではオレンジ色の 2 つの列に相当する部分です。こちらも以下のように簡潔な行列表現を与えることができます。

\mathrm{d}Y_{st} = \sum_i \mathrm{d}Z_{it} X_{is} = \sum_i (\mathbf{X}^T)_{si} \mathrm{d}Z_{it}
\mathrm{d}\mathbf{Y} = \mathbf{X}^T \mathrm{d}\mathbf{Z}

数式による導出

やっていることは上述の図式的な説明と実質的に同じですが、より形式的に、数式による導出を簡潔に示しておきます。

\mathbf{X} に関する勾配

\mathbf{X} は 2 次元行列なので、成分 X_{st} ごとに勾配を考えます。

\begin{aligned} \mathrm{d}X_{st} &= \sum_{i, j} \mathrm{d}Z_{ij} \frac{\partial Z_{ij}}{\partial X_{st}} \quad (\because \text{chain rule})\\ &= \sum_{i, j} \mathrm{d}Z_{ij} \frac{\partial}{\partial X_{st}} \left( \sum_{l} X_{il} Y_{lj} \right) \quad (\because \mathbf{Z} = \mathbf{X}\mathbf{Y})\\ &= \sum_{i, j} \mathrm{d}Z_{ij} \sum_{l} \left( \frac{\partial X_{il}}{\partial X_{st}} \right) Y_{lj} \\ &= \sum_{i, j} \mathrm{d}Z_{ij} \sum_{l} \delta_{is} \delta_{lt} Y_{lj} \quad (\because \frac{\partial X_{il}}{\partial X_{st}} = \delta_{is} \delta_{lt})\\ &= \sum_{i, j} \mathrm{d}Z_{ij} (\delta_{is} Y_{tj}) \quad (\because \sum_l \delta_{lt}Y_{lj} = Y_{tj}) \\ &= \sum_{j} \mathrm{d}Z_{sj} Y_{tj} \quad (\because \sum_i \mathrm{d}Z_{ij} \delta_{is} = \mathrm{d}Z_{sj}) \\ &= \sum_{j} \mathrm{d}Z_{sj} (Y^T)_{jt} \\ \end{aligned}

各成分について上式が成り立つことから、以下の行列表現が導出できます。

\mathrm{d}\mathbf{X} = \mathrm{d}\mathbf{Z}\, \mathbf{Y}^T

\mathbf{Y} に関する勾配

同様に成分 Y_{st} ごとに考えます。

\begin{aligned} \mathrm{d}Y_{st} &= \sum_{i, j} \mathrm{d}Z_{ij} \frac{\partial Z_{ij}}{\partial Y_{st}} \\ &= \sum_{i, j} \mathrm{d}Z_{ij} \frac{\partial}{\partial Y_{st}} \left( \sum_{l} X_{il} Y_{lj} \right) \\ &= \sum_{i, j} \mathrm{d}Z_{ij} \sum_{l} X_{il} \left( \frac{\partial Y_{lj}}{\partial Y_{st}} \right) \\ &= \sum_{i, j} \mathrm{d}Z_{ij} \sum_{l} X_{il} (\delta_{ls} \delta_{jt}) \\ &= \sum_{i, j} \mathrm{d}Z_{ij} (X_{is} \delta_{jt}) \\ &= \sum_{i} \mathrm{d}Z_{it} X_{is} \\ &= \sum_{i} (\mathbf{X}^T)_{si} \mathrm{d}Z_{it} \\ \end{aligned}

各成分について上式が成り立つことから、以下の行列表現が導出できます。

\mathrm{d}{Y} = \mathbf{X}^T \mathrm{d}\mathbf{Z}

覚え方

毎回導出するのも少々面倒なので、簡単な覚え方を考えてみました。

記号的な置き換え

\mathbf{Z} = \mathbf{X}\mathbf{Y} に対して \mathbf{Z} \to \mathrm{d}\mathbf{X}, \mathbf{X} \to \mathrm{d}\mathbf{Z}, \mathbf{Y} \to \mathbf{Y}^T のように機械的に置きかえます。
すると、以下が得られます。

\mathrm{d}\mathbf{X} = \mathrm{d}\mathbf{Z}\, \mathbf{Y}^T

同様に、 \mathbf{Z} \to \mathrm{d}\mathbf{Y}, \mathbf{Y} \to \mathrm{d}\mathbf{Z}, \mathbf{X} \to \mathbf{X}^T のように機械的に置きかえます。

\mathrm{d}\mathbf{Y} = \mathbf{X}^T\mathrm{d}\mathbf{Z}

覚えるポイントとしては以下です。

  • 勾配を考えたい入力と出力 \mathbf{Z} を逆転させ、両方 \mathrm{d} をつける。
  • もう一方の入力は転置 (\cdot)^T をとる。

行列の形状合わせ

行列の形状が合うように辻褄合わせをする方法を覚えておくのも役立ちそうです。

\mathbf{X}^T\mathbf{X}\mathbf{Y}\mathbf{Y}^T は一般的に正方行列です。正方行列は右や左から掛けても行列の形状を変化させません。
\mathbf{Z} = \mathbf{X}\mathbf{Y} に対して左右から \mathbf{X}^T\mathbf{Y}^T を掛けることで

\begin{aligned} \mathbf{Z}\mathbf{Y}^T &= \mathbf{X} (\mathbf{Y}\mathbf{Y}^T) \\ \mathbf{X}^T\mathbf{Z} &= (\mathbf{X}^T\mathbf{X}) \mathbf{Y} \end{aligned}

が得られ、 \mathbf{Z}\mathbf{Y}^T\mathbf{X}\mathbf{X}^T\mathbf{Z}\mathbf{Y} は同じ形状とわかります。
\mathbf{X}\mathrm{d}\mathbf{X} などは同じ形状ですから、以下は少なくとも行列の形状としては合っていることがわかります。

\begin{aligned} \mathrm{d}\mathbf{Z}\,\mathbf{Y}^T = \mathrm{d}\mathbf{X} \\ \mathbf{X}^T\mathrm{d}\mathbf{Z} = \mathrm{d}\mathbf{Y} \\ \end{aligned}
脚注
  1. 正直なところ、筆者はこの記法が一般的なのかどうかよくわかっていません。FlashAttention-2 を読んでいて見かけました。表記が簡潔になるのでここでも使用します ↩︎

GitHubで編集を提案

Discussion