🤖

行列積の勾配を誤差逆伝播法により求める

2021/08/11に公開約4,500字

ゼロから作るDeep Learning――Pythonで学ぶディープラーニングの理論と実装の行列積計算の誤差逆伝播法による勾配計算についての記述個人的に分かりづらく、また途中計算が省略されていたので自分への備忘録がてらこちらで途中計算を極力省かず説明していく。

誤差逆伝播法とは

誤差逆伝播法(backpropagation)は深層学習においてパラメタの更新をするために必要な勾配の計算を高速に行うためのアルゴリズムである。
微分(勾配)を計算するにあたり近似的に計算する方法(数値微分)ではCPUにとってハードな計算が必要であり学習の効率が非常に悪くなる。その問題の解決として現在の深層学習において広く用いられているのが誤差逆伝播法であり解析的な微分値を高速に計算することが可能である。

計算ノード

次のようなベクトル\bm{X}(入力値)および行列\bm{W}(重み)の行列積を誤差逆伝播法で計算することを考える。

本記事ではゼロから本に習いベクトルは行ベクトルと考え、\bm{X}(1 \times n)の行ベクトル、\bm{W}(n \times m)の行列とする。この時出力(\bm{X}\cdot \bm{W}) = \bm{Y}(1 \times m)の行ベクトルとなる。
まとめると以下のようになり、

\begin{align} &\bm{X} = \begin{pmatrix} x_1 & x_2 & \dots & x_n \end{pmatrix}\\ &\bm{W} = \begin{pmatrix} w_{11} & w_{12} & \dots & w_{1m} \\ w_{21} & w_{22} & \dots & w_{2m} \\ \vdots & \vdots & \ddots & \vdots \\ w_{n1} & w_{n2} & \dots & w_{nm} \end{pmatrix}\\ &\bm{Y} = \begin{pmatrix} y_1 & y_2 & \dots & y_m \end{pmatrix} \end{align}

この計算ノードの勾配を誤差逆伝播法により算出すると次のようになる。

\begin{align} \frac{\partial L}{\partial \bm{W}} = \bm{X}^\mathsf{T} \cdot \frac{\partial L}{\partial \bm{Y}}\\ \frac{\partial L}{\partial \bm{X}} = \frac{\partial L}{\partial \bm{Y}} \bm{W}^\mathsf{T} \end{align}

これから上の事実を導いていく。

前提

行列計算に関して

出力\bm{Y}の成分y_iは行列の積演算なので次のように計算できる。

\begin{align} y_i = \sum_{l = 1}^{m} w_{li}x_l \end{align}

偏微分に関して

n個の変数y_1, y_2, \dots, y_n(ただしy_iはそれぞれx_1, x_2, \dots, x_nについての関数とする)を持つ関数fx_iで偏微分することを考えると結果は偏微分の連鎖率より以下のようになる。

\begin{align} \frac{\partial f}{\partial x_i} = \sum_{j = 1}^{n} \frac{\partial f}{\partial y_j} \frac{\partial y_j}{\partial x_i} \end{align}

勾配の計算

\frac{\partial L}{\partial \bm{W}}を求める

まず各成分\frac{\partial L}{\partial w_{ij}}を求める。

\begin{align} \frac{\partial L}{\partial w_{ij}} &= \sum_{k = 1}^{m} \frac{\partial L}{\partial y_k} \frac{\partial y_k}{\partial w_{ij}}\\ &= \frac{\partial L}{\partial y_j} \frac{\partial y_j}{\partial w_{ij}}\\ &= \frac{\partial L}{\partial y_j} x_i \end{align}

計算を詳しく追っていくと、まず(8)の等号は偏微分の連鎖率(7)より成り立つ。次に(9),(10)の等号は(6)よりw_{ij}を含むのはjのみであり他の項は偏微分により消えるので成り立つ。

この結果を元に\frac{\partial L}{\partial \bm{W}}を計算すると

\begin{align} \frac{\partial L}{\partial \bm{W}} &= \begin{pmatrix} \frac{\partial L}{\partial y_1} x_1 & \frac{\partial L}{\partial y_2} x_1 & \dots & \frac{\partial L}{\partial y_m} x_1 \\ \frac{\partial L}{\partial y_1} x_2 & \frac{\partial L}{\partial y_2} x_2 & \dots & \frac{\partial L}{\partial y_m} x_2 \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial y_1} x_n & \frac{\partial L}{\partial y_2} x_n & \dots & \frac{\partial L}{\partial y_m} x_n \end{pmatrix}\\ &= \begin{pmatrix} x_1 \\ x_2 \\ \vdots \\ x_n \end{pmatrix} \cdot \begin{pmatrix} \frac{\partial L}{\partial y_1} & \frac{\partial L}{\partial y_2} & \dots & \frac{\partial L}{\partial y_m} \end{pmatrix}\\ &= \bm{X}^\mathsf{T} \cdot \frac{\partial L}{\partial \bm{Y}} \end{align}

\frac{\partial L}{\partial \bm{X}}を求める

まず各成分\frac{\partial L}{\partial x_{i}}を求める。

\begin{align} \frac{\partial L}{\partial x_i} &= \sum_{j = 1}^{m} \frac{\partial L}{\partial y_j} \frac{\partial y_j}{\partial x_i}\\ &= \sum_{j = 1}^{m} \frac{\partial L}{\partial y_j} w_{ij} \end{align}

計算を詳しく追っていくと、まず(14)の等号は先ほどと同様に偏微分の連鎖律により成り立つ。次に(15)の等号もまた先ほどと同様に(6)の式により算出される。
この結果を元に\frac{\partial L}{\partial \bm{X}}を計算すると

\begin{align} \frac{\partial L}{\partial \bm{X}} &= \begin{pmatrix} \sum_{j = 1}^{m} \frac{\partial L}{\partial y_j} w_{1j} & \sum_{j = 1}^{m} \frac{\partial L}{\partial y_j} w_{2j} & \dots & \sum_{j = 1}^{m} \frac{\partial L}{\partial y_j} w_{nj} \end{pmatrix}\\ &= \begin{pmatrix} \frac{\partial L}{\partial y_1} & \frac{\partial L}{\partial y_2} & \dots & \frac{\partial L}{\partial y_m} \end{pmatrix} \cdot \begin{pmatrix} w_{11} & w_{21} & \dots & w_{m1} \\ w_{12} & w_{22} & \dots & w_{m2} \\ \vdots & \vdots & \ddots & \vdots \\ w_{1n} & w_{2n} & \dots & w_{mn} \end{pmatrix}\\ &= \frac{\partial L}{\partial \bm{Y}} \bm{W}^\mathsf{T} \end{align}

Discussion

ログインするとコメントできます