👏

[nnabla]特定のレイヤーの勾配計算をさせない方法

2022/09/22に公開

はじめに

nnablaにおいて、特定の計算からの勾配を伝搬させたくないという状況を実装した時の備忘録。
例えば、L=w(y)\times L_1(x)としたときに、xについてはすべての勾配を伝搬させたいが、yについて、Lの計算における勾配\frac{\partial L}{\partial w}は計算したくないという状況。

環境

python==3.8.0
nnabla==1.30.0

実装

\displaystyle L = w(y) \times L_1(x)を次の形で実装してみる。

L_1 = x^2 \ ... \ (1) \\ w = 2y \ ... \ (2) \\ L = wL_1 \ ... \ (3)

この時、yについて、(2)の計算では勾配を取得したいが、(3)の計算では勾配を伝搬させたくないという状況を考える。ここで、Lに対するx,yの勾配はそれぞれ次のように計算できる。

\frac{\partial L}{\partial x} = \frac{\partial L}{\partial L_1} \cdot \frac{\partial L_1}{\partial x} = w \cdot 2x \ ... \ (4) \\ \frac{\partial L}{\partial y} = \frac{\partial L}{\partial w} \cdot \frac{\partial w}{\partial y} = L_1 \cdot 2 \ ... \ (5)

この時、(5)における\frac{\partial L}{\partial w}を計算せずに1として勾配をyに伝搬させたい状況を考える。

コードと結果

下記の形で実装した。

import nnabla as nn
x = nn.Variable(need_grad=True)
x.d = 3
y = nn.Variable(need_grad=True)
y.d = 5
L1 = x ** 2
w = 2 * y
L = w * L1
L.forward()
L.backward()

上記のコードを実行すると以下の結果が得られた。

L.d = 90
x.g = 60
y.g = 18

実際、x=3, y=5とした時、(3),(4),(5)から、L=90, \frac{\partial L}{\partial x}=60, \frac{\partial L}{\partial y}=18となる。
ここで、yにおける(3)の勾配のみを計算しない(すなわち、\frac{\partial L}{\partial w}:=1)とするためには、以下のように実装すれば良い。

import nnabla as nn
import nnabla.functions as F
x = nn.Variable(need_grad=True)
x.d = 3
y = nn.Variable(need_grad=True)
y.d = 5
L1 = x ** 2
w = 2 * y
w_i = F.identity(w)
w_i.need_grad = False     # <-- 勾配計算をFalse
L = w_i * L1 + w - w_i    # <-- 勾配が0にならないように工夫
L.forward()
L.backward()

上記のコードを実行すると以下の結果が得られる。
上記のコードでは、11行目において、第1項目でL1の勾配は計算できるようにしつつ、w_iの勾配は10行目で止めているので計算されない。しかし、このままだとwの勾配が0(初期値)のまま伝搬されるので、wを加算し、forwardの計算結果が変わらないようにw_iで引いている。これによって、11行目の計算におけるwの勾配は第2項目で計算される1のみとなる。

L.d = 90
x.g = 60
y.g = 2

実際、(5)において\frac{\partial L}{\partial w}=1とすると、\frac{\partial L}{\partial y}=2となる。

Discussion