📉

PyTorch の zero_grad() は何をしているか

2023/02/26に公開

最近研究で機械学習を使い始めたので、 PyTorch に入門してコードを書いているのですが、モデルの学習時における torch.optim.Optimizer.zero_grad() では何をしているのか理解できなかったので調べてみることにしました。

公式ドキュメントによると

https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html

まずは、PyTorch の公式ドキュメントを見てみましょう。

TORCH.OPTIM.OPTIMIZER.ZERO_GRAD
-- Sets the gradients of all optimized torch.Tensors to zero.

関数の説明を見に行ってみると、一行だけ説明文が書かれており、「すべての最適化対象の torch.Tensor の勾配を 0 にセットする」とありました。つまり、zero_grad() では、勾配降下法の計算を正しく行うためにモデル内のパラメータの勾配を初期化しているということですね。しかし、なぜこのような処理を明示的に行う必要があるのでしょうか。

Stack Overflow によると

https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch

同様の質問をしている Stack Overflow の回答が見つかったので、ご紹介します。意訳しているため、間違っている部分があるかもしれません。ご指摘ください。

PyTorch での訓練フェイズにおける毎回のミニバッチ計算では、重みやバイアスの更新といった誤差逆伝播計算を始める前に、通常明示的に勾配を 0 に設定する必要があります。なぜなら、 PyTorch が後続の逆伝播で勾配を累積するからです。こうした累積処理は、 RNN の訓練時や複数のミニバッチに渡って合計された損失関数の勾配を計算したいときに便利なため用いられています。loss.backward() が呼び出される度に毎回、勾配の累積処理が行われているのです。
このような理由により、訓練のループを開始する際に、パラメータを正しく更新するために。勾配を 0 にすべきということです。そうしなければ、勾配は、すでにモデルのパラメータを更新するために使用した古い勾配と新しく計算された勾配の足し合わせとなります。そのため、勾配降下の方向と異なる方向を指すことになってしまいます。

つまり、 PyTorch では、RNN の訓練やミニバッチの計算を効率的にするために、前回計算した勾配が残っている場合、前回計算した勾配と今回計算した勾配を足し合わせた勾配を利用するようにしているとのことです。

上記の回答では、シンプルな疑似コードを交えて説明してくださっていましたので、今回はそのコードを実際に実行して動作を確認してみました(元のコードは所々省略されていたので、実際に動作するように少し改変しています)。

実際に動かしながら挙動を確認する

以下のようなコードを使って、どのような挙動となっているのかを確認しました。

import torch
from torch.autograd import Variable
import torch.optim as optim
import re

def linear_model(x, W, b):
    return torch.matmul(x, W) + b


data = torch.Tensor([[1, 1], [3, 3]])
targets = torch.Tensor([10, 100])

W = Variable(torch.randn(2, 1), requires_grad=True)
b = Variable(torch.randn(1), requires_grad=True)

criterion = torch.nn.MSELoss()
optimizer = optim.Adam([W, b])

for i, (sample, target) in enumerate(zip(data, targets)):
    print(f"loop: {i + 1}")
    format_params("init", W, b)

    optimizer.zero_grad()
    format_params("optimizer.zero_grad()", W, b)

    output = linear_model(sample, W, b)
    loss = criterion(output.squeeze(), target)
    loss.backward()
    format_params("loss.backward()", W, b)

    optimizer.step()
    format_params("optimizer.step()", W, b)
    
 
 def format_params(timing, W, b):
    """見やすいように出力をフォーマット"""
    torch.set_printoptions(precision=4)
    print(f"--- {timing}")
    print(re.sub("\n|\t|       |, requires_grad=True|", "", f"    W = {W}, W_grad = {W.grad}"))
    print(re.sub("\n|\t|       |, requires_grad=True|", "", f"    b = {b}, b_grad = {b.grad}"))
    print("")

上記のコードの出力結果は次のようになります。

loop: 1
--- init
    W = tensor([[ 1.4219], [-0.2314]]), W_grad = None
    b = tensor([0.3739]), b_grad = None

--- optimizer.zero_grad()
    W = tensor([[ 1.4219], [-0.2314]]), W_grad = None
    b = tensor([0.3739]), b_grad = None

--- loss.backward()
    W = tensor([[ 1.4219], [-0.2314]]), W_grad = tensor([[-16.8712], [-16.8712]])
    b = tensor([0.3739]), b_grad = tensor([-16.8712])

--- optimizer.step()
    W = tensor([[ 1.4229], [-0.2304]]), W_grad = tensor([[-16.8712], [-16.8712]])
    b = tensor([0.3749]), b_grad = tensor([-16.8712])

loop: 2
--- init
    W = tensor([[ 1.4229], [-0.2304]]), W_grad = tensor([[-16.8712], [-16.8712]])
    b = tensor([0.3749]), b_grad = tensor([-16.8712])

--- optimizer.zero_grad()
    W = tensor([[ 1.4229], [-0.2304]]), W_grad = tensor([[0.], [0.]])
    b = tensor([0.3749]), b_grad = tensor([0.])

--- loss.backward()
    W = tensor([[ 1.4229], [-0.2304]]), W_grad = tensor([[-576.2858], [-576.2858]])
    b = tensor([0.3749]), b_grad = tensor([-192.0953])

--- optimizer.step()
    W = tensor([[ 1.4237], [-0.2297]]), W_grad = tensor([[-576.2858], [-576.2858]])
    b = tensor([0.3757]), b_grad = tensor([-192.0953])

まず1回目のループでは、パラメータ Wb の勾配である W.gradb.grad には None が入っており、 optimizer.zero_grad() の後もその値が変わっていないことがわかります。loss.backward() で誤差逆伝播計算が行われると、 W.gradb.grad に何らかの値が入っています。その後、 optimizer.step() でパラメータ Wb の値が更新されます。

続いて2回目のループでは、パラメータとその勾配の値は1回目のループの終了時点から変化がありません。前回計算され更新されたパラメータと勾配の値がそのまま残っていることが確認できます。ここで optimizer.zero_grad() を行うと、W.grad および b.grad の値が 0 になっていることがわかります。このように、 optimizer.zero_grad() を実行することでパラメータの勾配を 0 に設定することができます。

比較として、optimizer.zero_grad() を用いなかった場合の出力結果は次のようになります。

loop: 1
--- init
    W = tensor([[ 1.4219], [-0.2314]]), W_grad = None
    b = tensor([0.3739]), b_grad = None

--- loss.backward()
    W = tensor([[ 1.4219], [-0.2314]]), W_grad = tensor([[-16.8712], [-16.8712]])
    b = tensor([0.3739]), b_grad = tensor([-16.8712])

--- optimizer.step()
    W = tensor([[ 1.4229], [-0.2304]]), W_grad = tensor([[-16.8712], [-16.8712]])
    b = tensor([0.3749]), b_grad = tensor([-16.8712])

loop: 2
--- init
    W = tensor([[ 1.4229], [-0.2304]]), W_grad = tensor([[-16.8712], [-16.8712]])
    b = tensor([0.3749]), b_grad = tensor([-16.8712])

--- loss.backward()
    W = tensor([[ 1.4229], [-0.2304]]), W_grad = tensor([[-593.1570], [-593.1570]])
    b = tensor([0.3749]), b_grad = tensor([-208.9665])

--- optimizer.step()
    W = tensor([[ 1.4237], [-0.2297]]), W_grad = tensor([[-593.1570], [-593.1570]])
    b = tensor([0.3757]), b_grad = tensor([-208.9665])

勾配が 0 に初期化されていないため、 zero_grad() を用いていないときに比べて勾配の値が小さくなっています。これは、2回のループ内でそれぞれ計算された勾配の値が足し合わされているためです。

ご覧いただいたように、 optimizer.zero_grad()を呼ばないと、勾配の値が正しく計算されずに、モデルの学習が適切に行われない可能性があるため、注意しましょう。

Discussion