計算グラフ (Computational Graph) とは
平たく言えば 「単純な関数をつなげたもの」 。
-
Mul : 乗算 ( \times )
-
Add : 加算 ( + )
-
PowN : 冪乗 ( ^N )
ここでいう関数とは、上の Mul や Add や PowN のような最小単位の演算。
- 各ノードは 複数の入力値 をとり得る。
- 各ノードの 出力は単一 。
逆伝播
上のように順方向に計算グラフを解くことを 順伝播 (forward propagation) と呼ぶ。これに対して、逆方向に解くことを 逆伝播 (backward propagation) と呼ぶ。
入力値が複数ある場合は、入力値それぞれに対して計算を行う(例:出力→入力Aについて逆伝播の計算をする)。
逆伝播で求まること
まず、「逆伝播するとどうなるか」と考えるのはナンセンス。そうではなく「計算グラフを利用してある値を求める方法」として「逆から解く必要がある」ので「逆伝播」と呼ぶと思った方がよい。
何が求まるかというと、入力それぞれについて 計算グラフ全体での増加率 が求まる。
つまり、計算グラフ全体を G(\text{入力A},\text{入力B},\text{入力C},\cdots) とした場合、 \displaystyle \frac{\partial G}{\partial \text{任意の入力}} が求まる。
逆伝播の計算方法
入力Aにおける計算グラフ G 全体(入力A→出力への一本道)の増加率は、各ノードの増加率を掛け合わせれば求まる。
\frac{\partial G}{\partial \text{入力}A} = \frac{\partial f_1}{\partial \text{入力}_1} \times \frac{\partial f_2}{\partial \text{入力}_2} \times \frac{\partial f_3}{\partial \text{入力}_3} \cdots
※ f_n はノードnの関数、 \text{入力}_n はノードnの入力。
※ 「入力」「出力」は順方向におけるものを指すので注意。
※ 逆伝播が成り立つ理由は「連鎖律」で説明される。計算グラフは複数のノード(関数)で構成される、つまり合成関数であり、その全体の微分値(=変化率)は構成する関数(ノード)の微分値の積で求められるという定理。「変化率 \times 変化率 = 全体の変化率」という当たり前のことを言っているだけ。
単純化のためにノード n の増加率 \displaystyle \frac{\partial f_n}{\partial \text{入力}_n} = \Delta_n とする。
\Delta_G = \Delta_1 \times \Delta_2 \times \Delta_3 \cdots
このとき、あるノード n 以降 の増加率 ( \Delta_{n^+} と書く) は以下のように書ける。
\Delta_{n^+} = \Delta_n \times \Delta_{n+1} \times \Delta_{n+2} \cdots
つまり、
- 最後のノードの増加率を求める
- 一つ前のノードの増加率 \times 最後のノードの増加率を求める
- ...最初のノードまで遡る
という具合に後ろ→前に(逆伝播)求めていけば、グラフ全体の増加率が求まる。
例
例えば以下の計算グラフの場合(入力A以外のルートは省略)
\begin{align*}
\Delta_G &= \Delta_1 \times \Delta_2 \times \Delta_3 \\[1em]
\Delta_{3^+} &= \frac{\partial f_3}{\partial \text{入力}_3} \\[1em]
\Delta_{2^+} &= \frac{\partial f_2}{\partial \text{入力}_2} \times \Delta_{3^+} \\[1em]
\Delta_{1^+} &= \frac{\partial f_1}{\partial \text{入力}_1} \times \Delta_{2^+}
\end{align*}
逆伝播の例題
最初の計算グラフについて、逆伝播を用いて入力Aの増加率を求めてみよう。
-
Mul : 乗算 ( \times )
-
Add : 加算 ( + )
-
PowN : 冪乗 ( ^N )
以下のような順で解いていく。
\begin{align*}
\Delta_G &= \Delta_1 \times \Delta_2 \times \Delta_3 \\[1em]
\Delta_{3^+} &= \frac{\partial f_3}{\partial \text{入力}_3} \\[1em]
\Delta_{2^+} &= \frac{\partial f_2}{\partial \text{入力}_2} \times \Delta_{3^+} \\[1em]
\Delta_{1^+} &= \frac{\partial f_1}{\partial \text{入力}_1} \times \Delta_{2^+}
\end{align*}
ノード3 (冪乗)
\begin{cases}
\Delta_3 = \frac{\partial f_3}{\partial \text{入力}_3} \\[1em]
f_3 (\text{入力}_3) = (\text{入力}_3)^2
\end{cases}
\begin{align*}
\Delta_3 &= \frac{\partial f_3}{\partial \text{入力}_3} \\[1em]
&= \text{入力}_3 \times 2
\end{align*}
ノード2 (加算)
\begin{cases}
\Delta_2 = \frac{\partial f_2}{\partial \text{入力}_2} \\[1em]
f_2 (\text{入力}_{2}, \text{入力}_{C}) = \text{入力}_{2} + \text{入力}_{C}
\end{cases}
\begin{align*}
\Delta_2 &= \frac{\partial f_2}{\partial \text{入力}_2} \\[1em]
&= 1 \\[1em]
\Delta_{2^+} &= \Delta_2 \times \Delta_3 \\[1em]
&= 1 \times (\text{入力}_3 \times 2)
\end{align*}
ノード1 (乗算)
\begin{cases}
\Delta_1 = \frac{\partial f_1}{\partial \text{入力}_1} \\[1em]
f_1 (\text{入力}_{1}, \text{入力}_{B}) = \text{入力}_{1} \times \text{入力}_{B}
\end{cases}
※ \text{入力}_{1} = \text{入力}_{A} だが、文脈上 \text{入力}_{1} としている。
\begin{align*}
\Delta_1 &= \frac{\partial f_1}{\partial \text{入力}_1} \\[1em]
&= \text{入力}_{B} \\[1em]
\Delta_{1^+} &= \Delta_1 \times \Delta_{2^+} \\[1em]
&= \text{入力}_{B} \times 1 \times (\text{入力}_3 \times 2)
\end{align*}
求まったけど
ということで \displaystyle \frac{\partial G}{\partial \text{入力}_{A}} = 2 \text{入力}_{B} \text{入力}_3 と求まった。
この \text{入力}_3 は f_2 と f_1 で表現できるが、今回はそこまでしない。
なぜしないかというと、 計算グラフの逆伝播の利用シーンではこのように代数的に(変数を当てたまま)解くことをしない から。
薄々勘づかれているだろうが、そもそも、上のように代数的に解くのであれば逆伝播=逆から解いていく必要は別にない。もっというと別に計算グラフ的な考え方をする必要はなく、単に合成関数の微分をすればいいだけ。
逆伝播は、実際の値をそのまま使って計算していく
計算グラフの逆伝播の利用シーンでは、順伝播の際の実際の値を記録しておき、全ての変数をその値で固定して計算する。こうすることで、コンピュータで微分の計算が容易にできる。
- コンピュータで代数的に数式を扱うのはハード...どういうアルゴリズムで式変形するの?
- その解決策として数値微分があるが、ニューラルネットワークのような場合は逆伝播法の方が効率が良い。
- ノード単位の計算結果を再利用できるので。
- ニューラルネットワークの形が計算グラフにそのまま当てはまるので非常に効率が良い。
- 逆伝播法であれば解析的(↔︎数値微分による近似)であり、より正確。
逆伝播の例題 (順伝播を記録)
ということで、実際の利用シーンと同じように順伝播の値を記録し、その値を利用して解く。
-
Mul : 乗算 ( \times )
-
Add : 加算 ( + )
-
PowN : 冪乗 ( ^N )
つまり、 G(\text{入力A},\text{入力B},\text{入力C}) について \displaystyle \frac{\partial G}{\partial \text{入力A}} \Big|_{A=10,\,B=1.1,\,C=3} を求める。
ノード3 (冪乗)
\begin{cases}
\Delta_3 = \frac{\partial f_3}{\partial \text{入力}_3} \\[1em]
f_3 (\text{入力}_3) = (\text{入力}_3)^2 \\[1em]
\text{入力}_3 = 14
\end{cases}
\begin{align*}
\Delta_3 &= \frac{\partial f_3}{\partial \text{入力}_3} \\[1em]
&= \text{入力}_3 \times 2 \\[1em]
&= 28
\end{align*}
ノード2 (加算)
\begin{cases}
\Delta_2 = \frac{\partial f_2}{\partial \text{入力}_2} \\[1em]
f_2 (\text{入力}_{2}, \text{入力}_{C}) = \text{入力}_{2} + \text{入力}_{C} \\[1em]
\text{入力}_2 = 11 \\[1em]
\text{入力}_C = 3
\end{cases}
\begin{align*}
\Delta_2 &= \frac{\partial f_2}{\partial \text{入力}_2} \\[1em]
&= 1 \\[1em]
\Delta_{2^+} &= \Delta_2 \times \Delta_3 \\[1em]
&= 1 \times 28
\end{align*}
ノード1 (乗算)
\begin{cases}
\Delta_1 = \frac{\partial f_1}{\partial \text{入力}_1} \\[1em]
f_1 (\text{入力}_{1}, \text{入力}_{B}) = \text{入力}_{1} \times \text{入力}_{B} \\[1em]
\text{入力}_1 = 10 \\[1em]
\text{入力}_B = 1.1
\end{cases}
※ \text{入力}_{1} = \text{入力}_{A} だが、文脈上 \text{入力}_{1} としている。
\begin{align*}
\Delta_1 &= \frac{\partial f_1}{\partial \text{入力}_1} \\[1em]
&= \text{入力}_{B} \\[1em]
&= 1.1 \\[1em]
\Delta_{1^+} &= \Delta_1 \times \Delta_{2^+} \\[1em]
&= \text{入力}_{B} \times 1 \times (\text{入力}_3 \times 2) \\[1em]
&= 1.1 \times 1 \times 28 \\[1em]
&= 30.8
\end{align*}
求まった
求まった。コンピュータアルゴリズムの実装においては、実装に誤りがないか数値微分でテストするらしい。
数値微分してみる
\begin{align*}
G(A,B,C) &= (A \times B + C)^2 \\[1em]
\frac{G(A+h)-G(A-h)}{2h} &= \frac{G(A+10^{-4})-G(A-10^{-4})}{2 \times 10^{-4}} \\[1em]
\end{align*}
A=10, B=1.1, C=3 として
\begin{align*}
\frac{G(10+10^{-4})-G(10-10^{-4})}{2 \times 10^{-4}} &= \frac{((10+10^{-4}) \times 1.1 + 3)^2-((10-10^{-4}) \times 1.1 + 3)^2}{2 \times 10^{-4}} \\[1em]
&= 30.805
\end{align*}
近い値になった。
Discussion