🔖

PyTorch Custom Loss with NumPy

2024/02/01に公開

PyTorchでカスタマイズした目的(損失)関数を設定したい

1.はじめに

ディープラーニングの実装において、PyTorch はその柔軟性と効率性で広く利用されています。特に、カスタム損失関数を用いることで、特定の問題に合わせた最適化が可能になります。しかし、このプロセスには、 Numpy との相互作用に関連する一つの重要な落とし穴があります。本記事では、PyTorch の自動微分エンジンと Numpy を用いたカスタム損失関数の作成における問題点とその解決策について解説します。

2.理論

2-1.PyTorch の自動微分エンジンの基礎

PyTorch の最大の特徴の一つは、その自動微分エンジン、すなわち torch.autograd です。このエンジンは、ネットワークのパラメータに対する損失関数の勾配を自動的に計算し、効率的な学習を可能にします。勾配は、ニューラルネットワークを訓練する際に不可欠なもので、パラメータの最適化に用いられます。

2-2.Numpy と PyTorch の相互作用の問題

PyTorch のテンソルを Numpy 配列に変換してから処理を行うと、PyTorch の自動微分エンジンが勾配情報を失ってしまう問題があります。これは、Numpy が PyTorch の自動微分エンジンと互換性がないためです。つまり、Numpy で処理を行った結果を PyTorch のテンソルに戻しても、勾配情報が消失してしまうのです。

2-3.torch.autograd.Function の利用

この問題を解決するための鍵は、torch.autograd.Function を用いることです。これにより、カスタム損失関数内で独自の順伝播 (forward pass) と逆伝播 (backward pass) を定義することができます。

2-4.順伝播 (Forward Pass)

順伝播とは、入力データをネットワークを通して前方に伝播させ、出力を得るプロセスです。カスタム損失関数において、この段階では損失を計算します。

2-5.逆伝播 (Backward Pass)

逆伝播は、損失関数の勾配を計算し、これをネットワークのパラメータに逆伝播させるプロセスです。これにより、各パラメータの勾配が得られ、パラメータの更新が可能になります。

2-6.torch.autograd.Function の実装方法

  • サブクラスの作成: torch.autograd.Function のサブクラスを作成し、forward および backward メソッドを定義します。
  • forward メソッド: このメソッドでは、入力テンソルに対する損失計算を行います。必要に応じて、Numpy 配列に変換して計算を行うことができます。
  • backward メソッド: このメソッドでは、forward メソッドの出力に対する各入力の勾配を計算します。

3.実装例

3-1.問題設定

  • タスク:MNIST の画像生成モデル
  • 再現性を確保するためにシード値・学習時のミニバッチのシャッフルは False とする
  • 損失関数:平均二乗誤差
  • 通常の torch.nn.MSELoss と NumPy 配列による損失計算の値を比較
    torch.autograd.Function を適切に設定することで、カスタム損失においても自動微分およびモデルのパラメータ更新が正常に実行されていることを確認する)

3-2.再現性の確保

# ランダムシード
import random
import numpy as np
import torch

def fix_seeds(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# データのシャッフル
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('./', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=False)

3-3.モデル

import torch.nn as nn

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 12),
            nn.ReLU(),
            nn.Linear(12, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(),
            nn.Linear(12, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

3-4.通常の実装

import torch.optim as optim

fix_seeds()
model = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 2
for epoch in range(num_epochs):
    for i, data in enumerate(train_loader):
        img, _ = data
        img = img.view(img.size(0), -1).to(device)
        output = model(img.float())
        loss = criterion(output, img)

        optimizer.zero_grad()
        loss.backward()
        if i % 20 == 0:
            print(f'Step [{i+1:3}/{len(train_loader)}], Loss: {loss.item():.4f}')
        optimizer.step()

    print(f'=== Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
Step [  1/235], Loss: 0.2321
Step [ 21/235], Loss: 0.1319
Step [ 41/235], Loss: 0.0835
Step [ 61/235], Loss: 0.0688
Step [ 81/235], Loss: 0.0700
Step [101/235], Loss: 0.0665
Step [121/235], Loss: 0.0625
Step [141/235], Loss: 0.0624
Step [161/235], Loss: 0.0649
Step [181/235], Loss: 0.0608
Step [201/235], Loss: 0.0668
Step [221/235], Loss: 0.0600
=== Epoch [1/2], Loss: 0.0611
Step [  1/235], Loss: 0.0598
Step [ 21/235], Loss: 0.0625
Step [ 41/235], Loss: 0.0619
Step [ 61/235], Loss: 0.0565
Step [ 81/235], Loss: 0.0585
Step [101/235], Loss: 0.0550
Step [121/235], Loss: 0.0537
Step [141/235], Loss: 0.0538
Step [161/235], Loss: 0.0573
Step [181/235], Loss: 0.0545
Step [201/235], Loss: 0.0595
Step [221/235], Loss: 0.0555
=== Epoch [2/2], Loss: 0.0578

3-5.カスタム損失

ポイント

  • 勾配計算の理解: backward メソッドでは、損失関数の出力に対する各入力パラメータの勾配を計算します。これは、ニューラルネットワークのパラメータを更新するために使用される重要な情報です。勾配は、損失関数が入力パラメータに対してどのように変化するかを示します。

  • チェーンルールの適用: 逆伝播は、微分のチェーンルールに基づいています。複合関数の微分は、個々の関数の微分の積として表されます。したがって、各ステップで局所的な勾配を計算し、それらを連鎖させる必要があります。

平均二乗誤差は以下の式で定義されます。

MSE = \frac{1}{n}\sum_{i=1}^{n}(y_{i}-\widehat{y}_{i})^{2}

この損失関数に対する勾配は次のように計算されます。

\frac{\partial {MSE}}{\partial {\widehat{y}_{i}}}=2(y_{i}-\widehat{y}_{i})

これを元に、カスタム損失関数を定義します。

# customloss
class CustomLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, target):
        # PyTorch テンソルを NumPy 配列に変換
        input_np = input.cpu().detach().numpy()
        target_np = target.cpu().detach().numpy()
        # NumPy で二乗誤差を計算
        loss = np.mean((input_np - target_np) ** 2)
        # 損失計算を保存
        ctx.save_for_backward(input, target)
        # 損失を PyTorch テンソルに変換して返す
        return torch.tensor(loss, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # forwardで保存された入力とターゲットを取得
        input, target = ctx.saved_tensors
        # 勾配計算(grad_outputを使用してスケーリング)
        grad_input = grad_output * 2 * (input - target) / input.numel()
        return grad_input, None

# トレーニング
fix_seeds()
model = Autoencoder().to(device)
criterion = CustomLoss.apply
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 2
for epoch in range(num_epochs):
    for i, data in enumerate(train_loader):
        img, _ = data
        img = img.view(img.size(0), -1).to(device)
        output = model(img.float())
        loss = criterion(output, img)

        optimizer.zero_grad()
        loss.backward()
        if i % 20 == 0:
            print(f'Step [{i+1:3}/{len(train_loader)}], Loss: {loss.item():.4f}')
        optimizer.step()

    print(f'=== Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
Step [  1/235], Loss: 0.2321
Step [ 21/235], Loss: 0.1319
Step [ 41/235], Loss: 0.0835
Step [ 61/235], Loss: 0.0688
Step [ 81/235], Loss: 0.0700
Step [101/235], Loss: 0.0665
Step [121/235], Loss: 0.0625
Step [141/235], Loss: 0.0624
Step [161/235], Loss: 0.0649
Step [181/235], Loss: 0.0608
Step [201/235], Loss: 0.0668
Step [221/235], Loss: 0.0600
=== Epoch [1/2], Loss: 0.0611
Step [  1/235], Loss: 0.0598
Step [ 21/235], Loss: 0.0625
Step [ 41/235], Loss: 0.0619
Step [ 61/235], Loss: 0.0565
Step [ 81/235], Loss: 0.0585
Step [101/235], Loss: 0.0550
Step [121/235], Loss: 0.0537
Step [141/235], Loss: 0.0538
Step [161/235], Loss: 0.0573
Step [181/235], Loss: 0.0545
Step [201/235], Loss: 0.0595
Step [221/235], Loss: 0.0555
=== Epoch [2/2], Loss: 0.0578

4.まとめ

PyTorch と Numpy を組み合わせたカスタム損失関数の作成は、自動微分エンジンとの互換性を保ちながら行う必要があります。torch.autograd.Function を適切に利用することで、この課題を解決し、より柔軟なディープラーニングモデルの開発が可能になります。

Discussion