🔥

【誤差逆伝播法 -Back Propagation】計算グラフと自動微分

に公開

はじめに

誤差逆伝播法(Back Propagation)は深層学習の基礎的な概念ですが、Pytorch・TensorFlowでは自動化されており普段の業務において意識することは少ないです。下記書籍を拝読中に、丁寧な説明を目にしたため、復習を兼ねて読書メモを残します。

Containts

手計算で求める例

以下の関数に関して、(x, y, z) = (-2, 5, -4)における勾配を計算します。

f(x, y, z) = (x + y)z

この関数の偏微分係数は簡単な手計算により

\frac{\partial f}{\partial x} = z
\frac{\partial f}{\partial y} = z
\frac{\partial f}{\partial z} = x + y

であるから、これらに(x, y, z) = (-2, 5, -4)を代入することで

\bigg(\frac{\partial f}{\partial x}, \frac{\partial f}{\partial y}, \frac{\partial f}{\partial z}\bigg) = (-4, -4, 3)

という勾配が得られます。この例では関数が単純であるため、上記の計算で偏微分係数を求め、(x, y, z) = (-2, 5, -4)における勾配を得ることができます。

計算グラフで機械的に求める例

関数f(x, y, z) = (x + y)zを計算グラフにすると以下のような図になります(書籍を参考にdrawioにて写経)。

computation graph

計算手続きは以下です。

  1. xyの和でq(x, y) = x + y が計算される
  2. q(x, y)=zの積で関数値f(x, y, z)=(x + y)zが計算される

ここで (x, y, z) = (-2, 5, -4)のときに\partial f/\partial x の値を求めることを考えます。

合成関数の微分に関する連鎖律(chain rule)により,

\frac{\partial f}{\partial x}=\frac{\partial f}{\partial q}\frac{\partial q}{\partial x}

であるから\partial f/\partial q\partial q/\partial xの値がわかれば、\partial f/\partial xの値が求まります。

q(x, y) = x + yであるため

\frac{\partial q}{\partial x} = 1 
\frac{\partial q}{\partial y} = 1

次に\partial f/\partial qの値です。f(x, y, z)の値はq(x, y)の値から掛け算ノードを通じて計算されるので、掛け算ノードにおける計算を考えます。

g(u, v)=uvに対して

\frac{\partial g}{\partial u}=v, ~~~\frac{\partial g}{\partial v}=u

であることから、掛け算ノードを通過すると、ある入力関数に関する偏微分係数は、もう一方の入力変数の値そのものになります。したがって\partial f/\partial qの値は以下となります。

\frac{\partial f}{\partial q}=z=-4

最後に、目的である\partial f/\partial x の値は連鎖律より

\frac{\partial f}{\partial x}=\frac{\partial f}{\partial q}\frac{\partial q}{\partial x}=-4 \times1=-4

となり、xの偏微分係数を求められます。

以上の計算は、入力変数から関数値が計算されていくプロセスを複数の関数の合成関数として捉え、その合成関数を構成する個々の関数に関する偏微分係数の値をかけ合わせていくことにより元々の関数の勾配を求めます。この計算は計算グラフの複雑さにかかわらず、次の計算に手順で機械的に実行できます。

  1. 前向き計算 :入力から出力に向かって個々の関数値を順次計算する
  2. 後ろ向き計算:出力から入力に向かって偏微分計算を順次計算する

後ろ向き計算を行うには、最初に前向き計算(foward computation)によって関数値をすべて計算します。理由は各関数の偏微分係数の値を計算する際にその引数の値が必要になるからです。これは図におけるエッジ上の値を指します。これらの値を計算した後、
後ろ向き計算(backward computation)によって出力から出発して入力へと逆順に偏微分して計算していくことで、すべてのノードに関する出力値の偏微分係数を求めることができます。

上図の計算グラフにおける後ろ向き計算は次のように進行します。

  1. \partial f/\partial fの値は自明に1
  2. \partial f/\partial qの値はzの値を参照することで-4
  3. \partial f/\partial zの値はq(x, y)の値を参照することで3
  4. \partial f/\partial xの値は\partial f/\partial q = 1 とすでに得られている\partial f/\partial q = -4より-4 \times 1 = -4

Reference

  1. 自然言語処理の基礎 p.55-60

さらに理解を深めるには

Pytrochのtutorialが非常に参考になりました。

こちらは本記事では踏み込まなかったヤコビアン行列など複雑な計算のグラフについて触れています。

Pytorch Turorials > Deep Learning with Pytorch: A 60 Minite Blitz > A Gentle Introduction to torch.autograd

Discussion