🦩

【PyTorch】loss(損失関数)解説

2024/04/07に公開

1. loss(損失関数)とは

lossは、モデルの出力と正解データを比較し、その差異を計算する関数です。
このlossの出力が小さくなるようにモデルはパラメータを更新します。

2. 定義

PyTorchはnnモジュールで様々な損失関数を提供しています。

有名なところでは
nn.L1Loss(): 平均絶対誤差(正解データとの差の平均)
nn.MSELoss(): 平均二乗誤差(正解データとの差の二乗平均)
nn.CrossEntropyLoss(): クロスエントロピー誤差(分類問題の正解率)
nn.KLDivLoss: KL距離(分布間の誤差)

などがあります。

これらをそのまま使用する場合、以下のように損失関数を定義できます。
・損失関数の定義

import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

2.1 カスタム損失関数

PyTorchではカスタム損失関数を定義することができます。
・例

import torch
import torch.nn as nn

class MyMSELoss(nn.Module):
    def __init__(self):
        super(MyMSELoss, self).__init__()
    
    def forward(self, input, target):
        # 損失の計算:(input - target)の二乗の平均
        return torch.mean((input - target) ** 2)

My_loss_fc = MyMSELoss()

# 訓練ループの中で
for input, target in dataloader:
    optimizer.zero_grad()
    output = model(input)
    loss = My_loss_fc(output, target)
    loss.backward() # 逆伝播は、データの最終出力(Tensor)に対して行う
    optimizer.step()

注意点は以下の通りです。

  1. initでsuperによる親クラスメソッドの呼び出しを行う必要がある
  2. forwardで損失を計算し、損失を返す
  3. forward内の計算は全てTensorで行う必要がある(逆伝播のため)
  4. forwardの返り値は、基本的に単一のTensorが期待されます(backwardが単一のTensorを期待するため)。複数の出力は、total_loss = loss1 + loss2のようにまとめる必要があります。

3. 使い方

lossは以下の流れで利用されます。

  1. モデルが順伝播によって推論値を出力
  2. lossで、推論値と正解データを引数にとって損失を計算
  3. 逆伝播でパラメータの更新量を求める
  4. optimizerでパラメータを更新
optimizer.zero_grad()  # オプティマイザの勾配をリセット
output = model(input)  # モデルの出力を計算
loss = loss_fn(output, target)  # 損失を計算
loss.backward()        # バックプロパゲーションを実行
optimizer.step()       # オプティマイザを使用してパラメータを更新

loss.backward() によって、各パラメータの勾配は各パラメータに紐づけられます。
optimizerは、引数として受け取る model.parameters() から、この各パラメータに紐付けられた勾配を利用して各パラメータの更新量を計算し、optimizer.step()でパラメータを調整します。

4. まとめ

今回はPyTorchのlossについて解説しました。

Discussion