【誤差逆伝播法 -Back Propagation】計算グラフと自動微分
はじめに
誤差逆伝播法(Back Propagation)は深層学習の基礎的な概念ですが、Pytorch・TensorFlowでは自動化されており普段の業務において意識することは少ないです。下記書籍を拝読中に、丁寧な説明を目にしたため、復習を兼ねて読書メモを残します。
Containts
手計算で求める例
以下の関数に関して、
この関数の偏微分係数は簡単な手計算により
であるから、これらに
という勾配が得られます。この例では関数が単純であるため、上記の計算で偏微分係数を求め、
計算グラフで機械的に求める例
関数
計算手続きは以下です。
-
とx の和でy が計算されるq(x, y) = x + y -
とq(x, y)= の積で関数値z が計算されるf(x, y, z)=(x + y)z
ここで
合成関数の微分に関する連鎖律(chain rule)により,
であるから
次に
であることから、掛け算ノードを通過すると、ある入力関数に関する偏微分係数は、もう一方の入力変数の値そのものになります。したがって
最後に、目的である
となり、
以上の計算は、入力変数から関数値が計算されていくプロセスを複数の関数の合成関数として捉え、その合成関数を構成する個々の関数に関する偏微分係数の値をかけ合わせていくことにより元々の関数の勾配を求めます。この計算は計算グラフの複雑さにかかわらず、次の計算に手順で機械的に実行できます。
- 前向き計算 :入力から出力に向かって個々の関数値を順次計算する
- 後ろ向き計算:出力から入力に向かって偏微分計算を順次計算する
後ろ向き計算を行うには、最初に前向き計算(foward computation)によって関数値をすべて計算します。理由は各関数の偏微分係数の値を計算する際にその引数の値が必要になるからです。これは図におけるエッジ上の値を指します。これらの値を計算した後、
後ろ向き計算(backward computation)によって出力から出発して入力へと逆順に偏微分して計算していくことで、すべてのノードに関する出力値の偏微分係数を求めることができます。
上図の計算グラフにおける後ろ向き計算は次のように進行します。
-
の値は自明に1\partial f/\partial f -
の値は\partial f/\partial q の値を参照することでz -4 -
の値は\partial f/\partial z の値を参照することでq(x, y) 3 -
の値は\partial f/\partial x とすでに得られている\partial f/\partial q = 1 より\partial f/\partial q = -4 -4 \times 1 = -4
Reference
- 自然言語処理の基礎 p.55-60
さらに理解を深めるには
Pytrochのtutorialが非常に参考になりました。
こちらは本記事では踏み込まなかったヤコビアン行列など複雑な計算のグラフについて触れています。
Discussion