🔙

計算グラフと逆伝播法

に公開

計算グラフ (Computational Graph) とは

平たく言えば 「単純な関数をつなげたもの」

  • Mul : 乗算 ( \times )
  • Add : 加算 ( + )
  • PowN : 冪乗 ( ^N )

ここでいう関数とは、上の MulAddPowN のような最小単位の演算。

  • 各ノードは 複数の入力値 をとり得る。
  • 各ノードの 出力は単一

逆伝播

上のように順方向に計算グラフを解くことを 順伝播 (forward propagation) と呼ぶ。これに対して、逆方向に解くことを 逆伝播 (backward propagation) と呼ぶ。

入力値が複数ある場合は、入力値それぞれに対して計算を行う(例:出力→入力Aについて逆伝播の計算をする)。

逆伝播で求まること

まず、「逆伝播するとどうなるか」と考えるのはナンセンス。そうではなく「計算グラフを利用してある値を求める方法」として「逆から解く必要がある」ので「逆伝播」と呼ぶと思った方がよい。

何が求まるかというと、入力それぞれについて 計算グラフ全体での増加率 が求まる。

つまり、計算グラフ全体を G(\text{入力A},\text{入力B},\text{入力C},\cdots) とした場合、 \displaystyle \frac{\partial G}{\partial \text{任意の入力}} が求まる。

逆伝播の計算方法

入力Aにおける計算グラフ G 全体(入力A→出力への一本道)の増加率は、各ノードの増加率を掛け合わせれば求まる。

\frac{\partial G}{\partial \text{入力}A} = \frac{\partial f_1}{\partial \text{入力}_1} \times \frac{\partial f_2}{\partial \text{入力}_2} \times \frac{\partial f_3}{\partial \text{入力}_3} \cdots

f_n はノードnの関数、 \text{入力}_n はノードnの入力。
※ 「入力」「出力」は順方向におけるものを指すので注意。
※ 逆伝播が成り立つ理由は「連鎖律」で説明される。計算グラフは複数のノード(関数)で構成される、つまり合成関数であり、その全体の微分値(=変化率)は構成する関数(ノード)の微分値の積で求められるという定理。「変化率 \times 変化率 = 全体の変化率」という当たり前のことを言っているだけ。

単純化のためにノード n の増加率 \displaystyle \frac{\partial f_n}{\partial \text{入力}_n} = \Delta_n とする。

\Delta_G = \Delta_1 \times \Delta_2 \times \Delta_3 \cdots

このとき、あるノード n 以降 の増加率 ( \Delta_{n^+} と書く) は以下のように書ける。

\Delta_{n^+} = \Delta_n \times \Delta_{n+1} \times \Delta_{n+2} \cdots

つまり、

  1. 最後のノードの増加率を求める
  2. 一つ前のノードの増加率 \times 最後のノードの増加率を求める
  3. ...最初のノードまで遡る

という具合に後ろ→前に(逆伝播)求めていけば、グラフ全体の増加率が求まる。

例えば以下の計算グラフの場合(入力A以外のルートは省略)

\begin{align*} \Delta_G &= \Delta_1 \times \Delta_2 \times \Delta_3 \\[1em] \Delta_{3^+} &= \frac{\partial f_3}{\partial \text{入力}_3} \\[1em] \Delta_{2^+} &= \frac{\partial f_2}{\partial \text{入力}_2} \times \Delta_{3^+} \\[1em] \Delta_{1^+} &= \frac{\partial f_1}{\partial \text{入力}_1} \times \Delta_{2^+} \end{align*}

逆伝播の例題

最初の計算グラフについて、逆伝播を用いて入力Aの増加率を求めてみよう。

  • Mul : 乗算 ( \times )
  • Add : 加算 ( + )
  • PowN : 冪乗 ( ^N )

以下のような順で解いていく。

\begin{align*} \Delta_G &= \Delta_1 \times \Delta_2 \times \Delta_3 \\[1em] \Delta_{3^+} &= \frac{\partial f_3}{\partial \text{入力}_3} \\[1em] \Delta_{2^+} &= \frac{\partial f_2}{\partial \text{入力}_2} \times \Delta_{3^+} \\[1em] \Delta_{1^+} &= \frac{\partial f_1}{\partial \text{入力}_1} \times \Delta_{2^+} \end{align*}

ノード3 (冪乗)

\begin{cases} \Delta_3 = \frac{\partial f_3}{\partial \text{入力}_3} \\[1em] f_3 (\text{入力}_3) = (\text{入力}_3)^2 \end{cases}
\begin{align*} \Delta_3 &= \frac{\partial f_3}{\partial \text{入力}_3} \\[1em] &= \text{入力}_3 \times 2 \end{align*}

ノード2 (加算)

\begin{cases} \Delta_2 = \frac{\partial f_2}{\partial \text{入力}_2} \\[1em] f_2 (\text{入力}_{2}, \text{入力}_{C}) = \text{入力}_{2} + \text{入力}_{C} \end{cases}
\begin{align*} \Delta_2 &= \frac{\partial f_2}{\partial \text{入力}_2} \\[1em] &= 1 \\[1em] \Delta_{2^+} &= \Delta_2 \times \Delta_3 \\[1em] &= 1 \times (\text{入力}_3 \times 2) \end{align*}

ノード1 (乗算)

\begin{cases} \Delta_1 = \frac{\partial f_1}{\partial \text{入力}_1} \\[1em] f_1 (\text{入力}_{1}, \text{入力}_{B}) = \text{入力}_{1} \times \text{入力}_{B} \end{cases}

\text{入力}_{1} = \text{入力}_{A} だが、文脈上 \text{入力}_{1} としている。

\begin{align*} \Delta_1 &= \frac{\partial f_1}{\partial \text{入力}_1} \\[1em] &= \text{入力}_{B} \\[1em] \Delta_{1^+} &= \Delta_1 \times \Delta_{2^+} \\[1em] &= \text{入力}_{B} \times 1 \times (\text{入力}_3 \times 2) \end{align*}

求まったけど

ということで \displaystyle \frac{\partial G}{\partial \text{入力}_{A}} = 2 \text{入力}_{B} \text{入力}_3 と求まった。

この \text{入力}_3f_2f_1 で表現できるが、今回はそこまでしない。

なぜしないかというと、 計算グラフの逆伝播の利用シーンではこのように代数的に(変数を当てたまま)解くことをしない から。

薄々勘づかれているだろうが、そもそも、上のように代数的に解くのであれば逆伝播=逆から解いていく必要は別にない。もっというと別に計算グラフ的な考え方をする必要はなく、単に合成関数の微分をすればいいだけ。

逆伝播は、実際の値をそのまま使って計算していく

計算グラフの逆伝播の利用シーンでは、順伝播の際の実際の値を記録しておき、全ての変数をその値で固定して計算する。こうすることで、コンピュータで微分の計算が容易にできる。

  • コンピュータで代数的に数式を扱うのはハード...どういうアルゴリズムで式変形するの?
  • その解決策として数値微分があるが、ニューラルネットワークのような場合は逆伝播法の方が効率が良い。
    • ノード単位の計算結果を再利用できるので。
    • ニューラルネットワークの形が計算グラフにそのまま当てはまるので非常に効率が良い。
  • 逆伝播法であれば解析的(↔︎数値微分による近似)であり、より正確。

逆伝播の例題 (順伝播を記録)

ということで、実際の利用シーンと同じように順伝播の値を記録し、その値を利用して解く。

  • Mul : 乗算 ( \times )
  • Add : 加算 ( + )
  • PowN : 冪乗 ( ^N )

つまり、 G(\text{入力A},\text{入力B},\text{入力C}) について \displaystyle \frac{\partial G}{\partial \text{入力A}} \Big|_{A=10,\,B=1.1,\,C=3} を求める。

ノード3 (冪乗)

\begin{cases} \Delta_3 = \frac{\partial f_3}{\partial \text{入力}_3} \\[1em] f_3 (\text{入力}_3) = (\text{入力}_3)^2 \\[1em] \text{入力}_3 = 14 \end{cases}
\begin{align*} \Delta_3 &= \frac{\partial f_3}{\partial \text{入力}_3} \\[1em] &= \text{入力}_3 \times 2 \\[1em] &= 28 \end{align*}

ノード2 (加算)

\begin{cases} \Delta_2 = \frac{\partial f_2}{\partial \text{入力}_2} \\[1em] f_2 (\text{入力}_{2}, \text{入力}_{C}) = \text{入力}_{2} + \text{入力}_{C} \\[1em] \text{入力}_2 = 11 \\[1em] \text{入力}_C = 3 \end{cases}
\begin{align*} \Delta_2 &= \frac{\partial f_2}{\partial \text{入力}_2} \\[1em] &= 1 \\[1em] \Delta_{2^+} &= \Delta_2 \times \Delta_3 \\[1em] &= 1 \times 28 \end{align*}

ノード1 (乗算)

\begin{cases} \Delta_1 = \frac{\partial f_1}{\partial \text{入力}_1} \\[1em] f_1 (\text{入力}_{1}, \text{入力}_{B}) = \text{入力}_{1} \times \text{入力}_{B} \\[1em] \text{入力}_1 = 10 \\[1em] \text{入力}_B = 1.1 \end{cases}

\text{入力}_{1} = \text{入力}_{A} だが、文脈上 \text{入力}_{1} としている。

\begin{align*} \Delta_1 &= \frac{\partial f_1}{\partial \text{入力}_1} \\[1em] &= \text{入力}_{B} \\[1em] &= 1.1 \\[1em] \Delta_{1^+} &= \Delta_1 \times \Delta_{2^+} \\[1em] &= \text{入力}_{B} \times 1 \times (\text{入力}_3 \times 2) \\[1em] &= 1.1 \times 1 \times 28 \\[1em] &= 30.8 \end{align*}

求まった

求まった。コンピュータアルゴリズムの実装においては、実装に誤りがないか数値微分でテストするらしい。

数値微分してみる

\begin{align*} G(A,B,C) &= (A \times B + C)^2 \\[1em] \frac{G(A+h)-G(A-h)}{2h} &= \frac{G(A+10^{-4})-G(A-10^{-4})}{2 \times 10^{-4}} \\[1em] \end{align*}

A=10, B=1.1, C=3 として

\begin{align*} \frac{G(10+10^{-4})-G(10-10^{-4})}{2 \times 10^{-4}} &= \frac{((10+10^{-4}) \times 1.1 + 3)^2-((10-10^{-4}) \times 1.1 + 3)^2}{2 \times 10^{-4}} \\[1em] &= 30.805 \end{align*}

近い値になった。

Discussion