🦩
【PyTorch Method】loss(損失関数)解説
1. loss(損失関数)とは
lossは、モデルの出力と正解データを比較し、その差異を計算する関数です。
このlossの出力が小さくなるようにモデルはパラメータを更新します。
2. 定義
PyTorchはnnモジュールで様々な損失関数を提供しています。
有名なところでは
nn.L1Loss(): 平均絶対誤差(正解データとの差の平均)
nn.MSELoss(): 平均二乗誤差(正解データとの差の二乗平均)
nn.CrossEntropyLoss(): クロスエントロピー誤差(分類問題の正解率)
nn.KLDivLoss: KL距離(分布間の誤差)
などがあります。
これらをそのまま使用する場合、以下のように損失関数を定義できます。
・損失関数の定義
import torch.nn as nn
loss_fn = nn.CrossEntropyLoss()
2.1 カスタム損失関数
PyTorchではカスタム損失関数を定義することができます。
・例
import torch
import torch.nn as nn
class MyMSELoss(nn.Module):
def __init__(self):
super(MyMSELoss, self).__init__()
def forward(self, input, target):
# 損失の計算:(input - target)の二乗の平均
return torch.mean((input - target) ** 2)
My_loss_fc = MyMSELoss()
# 訓練ループの中で
for input, target in dataloader:
optimizer.zero_grad()
output = model(input)
loss = My_loss_fc(output, target)
loss.backward() # 逆伝播は、データの最終出力(Tensor)に対して行う
optimizer.step()
注意点は以下の通りです。
- initでsuperによる親クラスメソッドの呼び出しを行う必要がある
- forwardで損失を計算し、損失を返す
- forward内の計算は全てTensorで行う必要がある(逆伝播のため)
- forwardの返り値は、基本的に単一のTensorが期待されます(backwardが単一のTensorを期待するため)。複数の出力は、
total_loss = loss1 + loss2
のようにまとめる必要があります。
3. 使い方
lossは以下の流れで利用されます。
- モデルが順伝播によって推論値を出力
- lossで、推論値と正解データを引数にとって損失を計算
- 逆伝播でパラメータの更新量を求める
- optimizerでパラメータを更新
optimizer.zero_grad() # オプティマイザの勾配をリセット
output = model(input) # モデルの出力を計算
loss = loss_fn(output, target) # 損失を計算
loss.backward() # バックプロパゲーションを実行
optimizer.step() # オプティマイザを使用してパラメータを更新
loss.backward() によって、各パラメータの勾配は各パラメータに紐づけられます。
optimizerは、引数として受け取る model.parameters() から、この各パラメータに紐付けられた勾配を利用して各パラメータの更新量を計算し、optimizer.step()でパラメータを調整します。
4. まとめ
今回はPyTorchのlossについて解説しました。
Discussion