🌊

【PyTorch】model解説

2024/04/06に公開

今回はPyTorchのmodelについて解説します。

1. modelとは

modelは、PyTrochのtorch.nn.Moduleを継承して定義されるオブジェクトです。
データに対する処理を記述する、機械学習モデルの本体です。

import torch.nn as nn
# 定義
class My_Model(nn.Module):

2. 定義

PyTorchのmodelは、init関数とforword関数を持ちます。
init関数ではレイヤの定義を行います。
forword関数ではデータの流れを記述します。引数として入力を取り、処理層を通った出力を返り値とします。

この時、バッチサイズは考慮せずに記述することができ、学習時にはDataLoaderによってバッチサイズが自動的にデータの最初の次元に追加されます。
例: (channels, height, width) → (batch_size, channels, height, width)

3. 例

シンプルな畳み込みモデルの例を示します。

import torch.nn as nn
import torch.nn.functional as F

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5) # 畳み込み層 (入力チャネル, 出力チャネル, カーネルサイズ5×5)
        self.pool = nn.MaxPool2d(2, 2) # プーリング層 (2×2)
        self.fc1 = None # 全結合層 (後で定義)

    def forward(self, x):
        # 順伝播
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        
        # 全結合層の入力サイズを動的に計算
        num_flat_features = x.numel() // x.shape[0]
        
        # 全結合層を定義
        if self.fc1 is None:
            self.fc1 = nn.Linear(num_flat_features, 10)

        # 一次元化
        x = x.view(-1, num_flat_features)
        x = F.relu(self.fc1(x))
        return x

3.1 init関数

継承したnn.Moduleのinit関数を呼び出し、畳み込み層とプーリング層を定義しています。

3.2 forward関数

データの流れを記述しています。
引数に入力を取り、処理を行いその出力を返します。

以下でnumel()とview()を説明します。

numel()
  • numel()
    numelは、PyTorchのTensorクラスのメソッドで、テンソル内の要素の総数を返します。
    ・numelは"number of elements"の略称です。
    ・スカラーの場合は1を返します。
    ・ベクトルの場合はその長さを返します。
    ・多次元テンソルの場合は、すべての次元の要素数の積を返します。
    ・例
    import torch
    
    x = torch.tensor([1, 2, 3])
    print(x.numel())  # 出力: 3
      
    y = torch.tensor([[1, 2], [3, 4]])
    print(y.numel())  # 出力: 4
      
    z = torch.randn(2, 3, 4) 
    print(z.numel())  # 出力: 24 (2 * 3 * 4)
    
view()
  • view()
    viewもPyTorchのTensorクラスのメソッドで、テンソルの形状を変更するために使用されます。要素数が同じ場合、簡単にデータの形状を変更することができます。
    ・例1 基本
    x = torch.randn(2, 3, 4)
    print(x.shape)  # 出力: torch.Size([2, 3, 4])
      
    y = x.view(6, 4)
    print(y.shape)  # 出力: torch.Size([6, 4])
    
    ・例2 応用
    x = torch.randn(2, 3, 4)
    print(x.shape)  # 出力: torch.Size([2, 3, 4])
      
    y = x.view(-1, 12)
    print(y.shape)  # 出力: torch.Size([2, 12])
      
    z = x.view(2, -1)
    print(z.shape)  # 出力: torch.Size([2, 12])
    
    上記のように-1を使用して、より柔軟にデータ形状を変更することもできます。
    ・例3 注意点
    x = torch.arange(24).view(2, 3, 4)
    print(x)
    # 出力:
    # tensor([[[ 0,  1,  2,  3],
    #          [ 4,  5,  6,  7],
    #          [ 8,  9, 10, 11]],
    # 
    #         [[12, 13, 14, 15],
    #          [16, 17, 18, 19],
    #          [20, 21, 22, 23]]])
      
    y = x.view(2, 4, 3)
    print(y)
    # 出力:
    # tensor([[[ 0,  1,  2],
    #          [ 3,  4,  5],
    #          [ 6,  7,  8],
    #          [ 9, 10, 11]],
    # 
    #         [[12, 13, 14],
    #          [15, 16, 17],
    #          [18, 19, 20],
    #          [21, 22, 23]]])
    
    上記のように、viewでは高次元のデータから順番に番号が振られ、その番号をもとに形状を変更します。しかし、例えばこれが画像データだった場合、3のデータは右上にあったはずが、2段目の左に移動してしまっています。
    従って、view関数はどのようにデータが移動するのか確認して使う必要があります。

4. timmの活用

PyTorchでは、timm(Torchvision Image Models)と呼ばれる画像処理用のモデルを利用することができます。
これらは、非常に簡単に使用できます。

1. インストール

pipでインストールします

pip install timm

2. 定義

モデル名を指定して呼び出すだけで、モデルを定義できます。

import timm

# モデル名を指定してモデルを定義する
# pretrained=Trueを設定すると、事前に訓練された重みでモデルをロードする
model = timm.create_model('resnet50', pretrained=True)

また以下のコードで使用可能なモデルを確認できます。

import timm

# 使用可能なモデルのリストを表示
model_names = timm.list_models()
print(model_names)

3. カスタム

分類クラス数の変更や、転移学習に使用するために、timmモデルはカスタムを行うことができます。

・カスタム

import timm
import torch.nn as nn

# resnet50を定義
model = timm.create_model('resnet50', pretrained=True)

# 最終層をカスタマイズ。分類クラスを1000から10に変更
num_features = model.fc.in_features # 最終層(model.fc)の入力を取得
model.fc = nn.Linear(num_features, 10) # 最終層を変更

※model.fcが元のモデルで最終層として定義されている必要があります。

他にも畳み込み層の変更を行う場合は次のようになります。

import torchvision.models as models

# モデルをロード
model = models.resnet50(pretrained=True)

# 最初の畳み込み層にアクセス
first_conv_layer = model.conv1

# 新しい畳み込み層を作成し、置き換える
# 例:入力チャネル数が異なる場合(グレースケール画像用に変更するなど)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

5. 使い方

学習ではバッチサイズデータに対してmodelを実行し、その度にモデルのパラメータを更新します。

・例

# 学習
num_epochs = 10 # 全データを使用する回数

for epoch in range(num_epochs):
    for inputs, targets in train_dataloader: # バッチサイズの入力データと教師データを提供
        optimizer.zero_grad() # 勾配情報の消去
        outputs = model(inputs) # データをモデルに通す
        loss = criterion(outputs, targets) # 損失関数の定義
        loss.backward() # 逆伝播でパラメータ更新値を計算
        optimizer.step() # パラメータを更新

まとめ

PyTorchのmodelは、init関数とforward関数により定義されます。層の組み合わせで様々なモデルの構築が可能になりますが、データの形状が一致するように気をつける必要があります。
また学習時には、入力を与えるだけで良いのでコードを簡潔に記述することができます。

今回は以上です。

Discussion