【PyTorch】model解説
今回はPyTorchのmodelについて解説します。
1. modelとは
modelは、PyTrochのtorch.nn.Module
を継承して定義されるオブジェクトです。
データに対する処理を記述する、機械学習モデルの本体です。
import torch.nn as nn
# 定義
class My_Model(nn.Module):
2. 定義
PyTorchのmodelは、init関数とforword関数を持ちます。
init関数ではレイヤの定義を行います。
forword関数ではデータの流れを記述します。引数として入力を取り、処理層を通った出力を返り値とします。
この時、バッチサイズは考慮せずに記述することができ、学習時にはDataLoaderによってバッチサイズが自動的にデータの最初の次元に追加されます。
例: (channels, height, width) → (batch_size, channels, height, width)
3. 例
シンプルな畳み込みモデルの例を示します。
import torch.nn as nn
import torch.nn.functional as F
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5) # 畳み込み層 (入力チャネル, 出力チャネル, カーネルサイズ5×5)
self.pool = nn.MaxPool2d(2, 2) # プーリング層 (2×2)
self.fc1 = None # 全結合層 (後で定義)
def forward(self, x):
# 順伝播
x = self.conv1(x)
x = F.relu(x)
x = self.pool(x)
# 全結合層の入力サイズを動的に計算
num_flat_features = x.numel() // x.shape[0]
# 全結合層を定義
if self.fc1 is None:
self.fc1 = nn.Linear(num_flat_features, 10)
# 一次元化
x = x.view(-1, num_flat_features)
x = F.relu(self.fc1(x))
return x
3.1 init関数
継承したnn.Moduleのinit関数を呼び出し、畳み込み層とプーリング層を定義しています。
3.2 forward関数
データの流れを記述しています。
引数に入力を取り、処理を行いその出力を返します。
以下でnumel()とview()を説明します。
numel()
- numel()
numelは、PyTorchのTensorクラスのメソッドで、テンソル内の要素の総数を返します。
・numelは"number of elements"の略称です。
・スカラーの場合は1を返します。
・ベクトルの場合はその長さを返します。
・多次元テンソルの場合は、すべての次元の要素数の積を返します。
・例import torch x = torch.tensor([1, 2, 3]) print(x.numel()) # 出力: 3 y = torch.tensor([[1, 2], [3, 4]]) print(y.numel()) # 出力: 4 z = torch.randn(2, 3, 4) print(z.numel()) # 出力: 24 (2 * 3 * 4)
view()
- view()
viewもPyTorchのTensorクラスのメソッドで、テンソルの形状を変更するために使用されます。要素数が同じ場合、簡単にデータの形状を変更することができます。
・例1 基本・例2 応用x = torch.randn(2, 3, 4) print(x.shape) # 出力: torch.Size([2, 3, 4]) y = x.view(6, 4) print(y.shape) # 出力: torch.Size([6, 4])
上記のようにx = torch.randn(2, 3, 4) print(x.shape) # 出力: torch.Size([2, 3, 4]) y = x.view(-1, 12) print(y.shape) # 出力: torch.Size([2, 12]) z = x.view(2, -1) print(z.shape) # 出力: torch.Size([2, 12])
-1
を使用して、より柔軟にデータ形状を変更することもできます。
・例3 注意点上記のように、viewでは高次元のデータから順番に番号が振られ、その番号をもとに形状を変更します。しかし、例えばこれが画像データだった場合、3のデータは右上にあったはずが、2段目の左に移動してしまっています。x = torch.arange(24).view(2, 3, 4) print(x) # 出力: # tensor([[[ 0, 1, 2, 3], # [ 4, 5, 6, 7], # [ 8, 9, 10, 11]], # # [[12, 13, 14, 15], # [16, 17, 18, 19], # [20, 21, 22, 23]]]) y = x.view(2, 4, 3) print(y) # 出力: # tensor([[[ 0, 1, 2], # [ 3, 4, 5], # [ 6, 7, 8], # [ 9, 10, 11]], # # [[12, 13, 14], # [15, 16, 17], # [18, 19, 20], # [21, 22, 23]]])
従って、view関数はどのようにデータが移動するのか確認して使う必要があります。
4. timmの活用
PyTorchでは、timm(Torchvision Image Models)と呼ばれる画像処理用のモデルを利用することができます。
これらは、非常に簡単に使用できます。
1. インストール
pipでインストールします
pip install timm
2. 定義
モデル名を指定して呼び出すだけで、モデルを定義できます。
import timm
# モデル名を指定してモデルを定義する
# pretrained=Trueを設定すると、事前に訓練された重みでモデルをロードする
model = timm.create_model('resnet50', pretrained=True)
また以下のコードで使用可能なモデルを確認できます。
import timm
# 使用可能なモデルのリストを表示
model_names = timm.list_models()
print(model_names)
3. カスタム
分類クラス数の変更や、転移学習に使用するために、timmモデルはカスタムを行うことができます。
・カスタム
import timm
import torch.nn as nn
# resnet50を定義
model = timm.create_model('resnet50', pretrained=True)
# 最終層をカスタマイズ。分類クラスを1000から10に変更
num_features = model.fc.in_features # 最終層(model.fc)の入力を取得
model.fc = nn.Linear(num_features, 10) # 最終層を変更
※model.fcが元のモデルで最終層として定義されている必要があります。
他にも畳み込み層の変更を行う場合は次のようになります。
import torchvision.models as models
# モデルをロード
model = models.resnet50(pretrained=True)
# 最初の畳み込み層にアクセス
first_conv_layer = model.conv1
# 新しい畳み込み層を作成し、置き換える
# 例:入力チャネル数が異なる場合(グレースケール画像用に変更するなど)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
5. 使い方
学習ではバッチサイズデータに対してmodelを実行し、その度にモデルのパラメータを更新します。
・例
# 学習
num_epochs = 10 # 全データを使用する回数
for epoch in range(num_epochs):
for inputs, targets in train_dataloader: # バッチサイズの入力データと教師データを提供
optimizer.zero_grad() # 勾配情報の消去
outputs = model(inputs) # データをモデルに通す
loss = criterion(outputs, targets) # 損失関数の定義
loss.backward() # 逆伝播でパラメータ更新値を計算
optimizer.step() # パラメータを更新
まとめ
PyTorchのmodelは、init関数とforward関数により定義されます。層の組み合わせで様々なモデルの構築が可能になりますが、データの形状が一致するように気をつける必要があります。
また学習時には、入力を与えるだけで良いのでコードを簡潔に記述することができます。
今回は以上です。
Discussion