Intro2DL:Pytorchの基礎 XORの例題をもとに
例題:XOR
PytorchにおけるNNの学習/推論を行うにあたり、XORの例題を使用します。
XORとは、
この例題を解くために、Pytorchのモジュールを使用していきます。
import os
import numpy as np
import time
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
モデル構築の流れ
基本的には以下の流れで実行します。
- モデルの定義
- DataLoaderの準備
- lossやoptimizerなどの準備
~ここから学習開始~ - DataLoaderからバッチの取得
- バッチをモデルへ入力し予測値を得る
- 予測値と実測値からLossを計算する
- ロスに関して各パラメータの勾配を計算する
- 勾配方向へモデルのパラメータを更新する
- 3~7の操作をイタレーションの回数行う
1. モデルの定義
XORの例題から、入力は
また、モデルは活性化関数がtanh, 隠れ層を1つだけ持つネットワークを考えます。
import torch
import torch.nn as nn
Pytorchのモデルは以下のような構成である必要があります。
class TemplateNet(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
# xを入力とした時の計算
pass
class SimpleClassifier(nn.Module):
def __init__(self, num_in, num_hid, num_out):
super().__init__()
self.layer1 = nn.Linear(num_in, num_hid)
self.act_fn = nn.Tanh()
self.layer2 = nn.Linear(num_hid, num_out)
def forward(self, x):
x = self.layer1(x)
x = self.act_fn(x)
x = self.layer2(x)
return x
model = SimpleClassifier(num_in=2, num_hid=4, num_out=1)
print(model)
SimpleClassifier(
(layer1): Linear(in_features=2, out_features=4, bias=True)
(act_fn): Tanh()
(layer2): Linear(in_features=4, out_features=1, bias=True)
)
この時、モデルが持つパラメータを見てみます。Tanh
はパラメータを持たないので、layerのパラメータのみ表示されます。
for name, param in model.named_parameters():
print(f"Parameter {name}, shape {param.shape}")
Parameter layer1.weight, shape torch.Size([4, 2])
Parameter layer1.bias, shape torch.Size([4])
Parameter layer2.weight, shape torch.Size([1, 4])
Parameter layer2.bias, shape torch.Size([1])
注意点としては、layerの定義の仕方です。self.?
のような形で定義しないと以下のように認識/登録されないです。
class SimpleClassifier(nn.Module):
def __init__(self, num_in, num_hid, num_out):
super().__init__()
layer1 = nn.Linear(num_in, num_hid)
act_fn = nn.Tanh()
layer2 = nn.Linear(num_hid, num_out)
self.list_layer = [layer1, act_fn, layer2]
def forward(self, x):
for layer in self.list_layer:
x = layer(x)
return x
model = SimpleClassifier(num_in=2, num_hid=4, num_out=1)
print(model)
for name, param in model.named_parameters():
print(f"Parameter {name}, shape {param.shape}")
SimpleClassifier()
上記のように定義したい場合は、nn.ModuleList
やnn.ModuleDict
, nn.Sequential
を使用します。ここでは、nn.ModuleList
の例のみを見ます。
class SimpleClassifier(nn.Module):
def __init__(self, num_in, num_hid, num_out):
super().__init__()
layer1 = nn.Linear(num_in, num_hid)
act_fn = nn.Tanh()
layer2 = nn.Linear(num_hid, num_out)
self.list_layer = nn.ModuleList([layer1, act_fn, layer2])
def forward(self, x):
for layer in self.list_layer:
x = layer(x)
return x
model = SimpleClassifier(num_in=2, num_hid=4, num_out=1)
print(model)
for name, param in model.named_parameters():
print(f"Parameter {name}, shape {param.shape}")
SimpleClassifier(
(list_layer): ModuleList(
(0): Linear(in_features=2, out_features=4, bias=True)
(1): Tanh()
(2): Linear(in_features=4, out_features=1, bias=True)
)
)
Parameter list_layer.0.weight, shape torch.Size([4, 2])
Parameter list_layer.0.bias, shape torch.Size([4])
Parameter list_layer.2.weight, shape torch.Size([1, 4])
Parameter list_layer.2.bias, shape torch.Size([1])
2. DataLoaderの準備
pytorchでデータを扱う際は、Dataset
とDataLoader
を使用します。
シンプルには、Dataset
はi番目のデータを取得するためのクラスで、DataLoader
はバッチ処理などを効率的に実装できるクラスです。
# dataを効率的に扱うモジュールをインポート
import torch.utils.data as data
Dataset
クラスはi番目のデータを返す__getitem__()
とデータのサイズを返す__len__()
を持ちます。
class XORDataset(data.Dataset):
def __init__(self, size, std=0.1):
"""
Inputs:
size - Number of data points we want to generate
std - Standard deviation of the noise (see generate_continuous_xor function)
"""
super().__init__()
self.size = size
self.std = std
self.generate_continuous_xor()
def generate_continuous_xor(self):
data = torch.randint(low=0, high=2, size=(self.size, 2), dtype=torch.float32)
label = (data.sum(dim=1) == 1).to(torch.long)
data += self.std * torch.randn(data.shape)
self.data = data
self.label = label
def __len__(self):
return self.size
def __getitem__(self, idx):
data_point = self.data[idx]
data_label = self.label[idx]
return data_point, data_label
dataset = XORDataset(size=2500)
print(dataset.size)
print(dataset[0])
2500
(tensor([ 0.9450, -0.0377]), tensor(1))
DataLoader
クラスは上記で定義したDataset
の__getitem__
を使用して、バッチ処理などをよしなに実行してくれます。
オプションは以下の通りです。
- batch_size: バッチサイズを指定します
- shuffle: データセットの並び順をシャッフルするかどうか
- pin_memory: GPU上のメモリにデータをコピーします。サイズが大きい時には有効ですが、GPUのメモリを消費するので単なる検証や推論の時には必要ないです。
- drop_last: batch_sizeでデータの数を割り切れない時の余りを使用するかどうか(訓練時のみバッチサイズを一定に保つために必要)
data_loader = data.DataLoader(dataset, batch_size=128, shuffle=True, drop_last=True)
print(data_loader)
data_inputs, data_labels = next(iter(data_loader))
print("Data inputs", data_inputs.shape, "\n", data_inputs)
print("Data labels", data_labels.shape, "\n", data_labels)
<torch.utils.data.dataloader.DataLoader object at 0x108bbde20>
Data inputs torch.Size([128, 2])
tensor([[-7.2804e-02, 1.1041e+00],
[ 1.1450e+00, 1.7077e-02],
[ 4.0985e-02, -5.4373e-02],
[ 7.5966e-02, 9.7802e-01],
[ 2.2727e-02, 8.9922e-01],
[ 1.2303e-01, 1.0085e+00],
[-7.5358e-02, 7.5478e-02],
[ 1.0237e+00, 1.1585e+00],
[ 1.0081e+00, -6.0899e-02],
[ 1.0971e+00, 8.7675e-01],
[ 1.2261e+00, 1.0328e+00],
[-3.4534e-02, 8.2188e-01],
[-1.6522e-02, 9.7790e-01],
[ 9.6733e-01, 1.0362e+00],
[ 1.0112e-01, 8.8916e-01],
[-5.9085e-02, 8.7525e-01],
[ 2.0926e-02, 1.0810e-01],
[ 1.6433e-02, -9.2412e-03],
[ 1.0063e+00, 9.2770e-01],
[ 2.3867e-02, 9.4516e-01],
[ 9.3029e-01, 1.0644e+00],
[ 1.0317e+00, -5.0927e-04],
[-7.4040e-02, 1.1017e+00],
[-5.3543e-04, 2.0390e-02],
[ 7.6709e-02, 9.2210e-01],
[ 1.9267e-02, 1.0064e-01],
[-1.4015e-01, 7.9226e-01],
[ 2.3145e-01, 9.9505e-01],
[ 9.6408e-01, 4.8191e-02],
[ 6.9211e-02, -7.1338e-02],
[ 3.2373e-02, 1.0113e+00],
[ 1.2508e+00, 8.0529e-02],
[ 1.0965e+00, -1.4765e-01],
[-4.0061e-02, -5.2052e-02],
[-2.0837e-01, 8.3516e-01],
[ 1.0791e+00, -7.4755e-02],
[ 7.9891e-01, 2.0973e-01],
[ 9.0706e-01, 7.6696e-01],
[ 9.1724e-01, 9.1130e-01],
[ 1.1240e+00, -2.8604e-02],
[-1.5656e-01, -6.1603e-02],
[ 3.9071e-02, 1.0703e+00],
[ 1.1136e-01, 7.6418e-02],
[-4.0319e-02, 9.1122e-01],
[-1.2613e-01, 9.2346e-01],
[ 8.6311e-01, 4.1220e-03],
[-1.7388e-01, 1.1669e-01],
[-7.4471e-02, -1.9813e-01],
[ 1.0433e+00, 1.0701e+00],
[ 1.2681e-01, 1.0351e+00],
[-7.1518e-02, -2.9065e-01],
[-3.2377e-02, 1.0505e+00],
[-2.6789e-01, 1.2169e-01],
[ 1.8959e-01, -1.2496e-01],
[ 2.9532e-02, 1.2342e-01],
[ 8.3225e-01, -9.2726e-02],
[ 9.6130e-01, 1.0646e+00],
[ 9.1031e-01, 9.7370e-01],
[ 1.0166e+00, 1.0392e+00],
[ 2.0528e-01, -1.2846e-02],
[ 1.0167e+00, 9.3552e-01],
[ 8.8396e-01, 9.7615e-01],
[ 1.0432e+00, 1.0095e+00],
[ 9.9360e-01, 1.0981e+00],
[ 1.3962e-03, 1.0370e+00],
[ 3.1574e-02, 7.6176e-03],
[ 1.1352e+00, 1.0465e+00],
[ 8.2953e-03, -9.6899e-02],
[ 1.0000e+00, 3.2121e-02],
[ 9.7082e-02, 8.8162e-01],
[ 5.3385e-02, 1.0033e-01],
[ 9.5098e-01, 1.3654e-02],
[ 6.9419e-02, 1.0831e+00],
[ 1.0508e+00, 1.0540e+00],
[ 2.6655e-02, 8.9375e-01],
[-1.0887e-01, 2.2077e-02],
[ 8.7785e-01, 9.4118e-01],
[-4.8292e-03, -1.5647e-01],
[ 7.2836e-01, 6.2586e-02],
[-6.3113e-03, 2.2185e-02],
[-1.3274e-01, 8.6549e-01],
[ 1.0961e+00, 1.0224e+00],
[ 1.0638e+00, 1.5344e-02],
[-9.7978e-02, 1.6069e-01],
[ 1.0346e+00, 1.1706e+00],
[ 1.2919e-01, -8.1029e-02],
[ 1.0119e+00, 9.2153e-01],
[-2.8104e-02, -2.3104e-02],
[ 1.1002e-01, 1.0594e+00],
[ 2.3332e-01, -8.6315e-02],
[ 1.4667e-01, 9.6280e-01],
[ 1.1106e+00, 8.0643e-01],
[ 1.2023e-01, 1.0180e+00],
[-6.0721e-02, 1.1322e+00],
[ 8.6859e-01, 9.8454e-01],
[ 1.0771e-01, -8.9369e-02],
[ 4.4324e-02, 1.0185e+00],
[-8.9303e-02, 9.8590e-01],
[ 9.6519e-01, 1.0608e+00],
[ 9.3951e-01, -5.4688e-02],
[ 2.2585e-02, 9.2614e-01],
[-7.6468e-02, 2.1452e-01],
[ 7.2733e-02, -7.3930e-02],
[ 1.1124e+00, 1.1067e+00],
[ 2.8434e-01, -9.9704e-02],
[ 1.0294e+00, -1.8035e-01],
[-2.8473e-02, 9.1512e-01],
[ 1.9851e-01, 1.5367e-03],
[ 1.1980e-01, 8.8696e-01],
[ 2.0547e-01, 1.0220e+00],
[-1.9407e-01, 8.4612e-01],
[-1.5697e-01, -1.0190e-01],
[-8.0814e-02, 1.1973e-01],
[ 1.0253e+00, 9.5834e-01],
[ 1.0403e+00, 1.0607e+00],
[ 9.2056e-01, 6.0310e-02],
[ 9.5124e-01, 1.2482e+00],
[-2.2171e-02, 9.7395e-02],
[-2.3042e-01, 3.6538e-02],
[ 8.7662e-02, 9.6314e-01],
[ 6.3132e-03, 7.1874e-02],
[ 1.6658e-01, 1.4259e-01],
[ 9.4618e-01, 8.5934e-01],
[ 1.1190e+00, 9.6802e-01],
[ 7.7979e-02, 9.1017e-01],
[ 1.0036e+00, 1.1159e+00],
[ 1.0358e+00, 9.7519e-01],
[ 2.8420e-02, -3.8281e-02]])
Data labels torch.Size([128])
tensor([1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0,
1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0,
0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1,
1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0,
1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1,
0, 0, 0, 0, 1, 0, 0, 0])
3. lossやoptimizerなどの準備
Lossの定義
今回は2値分類なので、Binary Cross Entropy (BCE)ロスを使用します。pytorchではBCE lossはnn.BCELoss()
とnn.BCEWithLogitsLoss()
の2種類あります。nn.BCEWithLogitsLoss()
はSigmoidの層とBCE lossが1つのクラスになったもので、1つに結合することでlog-sum-exp trickにより数値的に安定します。そのため、今回はモデルの定義のところではSigmoidの層を抜きにして、ロスに``nn.BCEWithLogitsLoss()`を使用します。
loss_module = nn.BCEWithLogitsLoss()
otimizzerの定義
様々な最適化手法がありますが、今回は確率的勾配降下法(Stochastic Gradient Descent; SGD)を使用します。勾配の更新レベルを決める学習率を指定する必要がありますが、今回の小さなネットワークには0.1を設定します。
optimizerは.step()
と.zero_grad()
という関数を持ちます。.step()
関数は計算された勾配情報を元にパラメータを更新します。.zero_grad()
関数はすべてのパラメータの勾配情報を0にします。これがなぜ必要なのかというと、勾配情報は計算されるたびに前回計算された値に加算されていくためです。そのため、.backward()
の前に必ず.zero_grad()
を行う必要があります。
# ここで、modelのパラメータを入力します
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
モデルをGPUへ移す
データをGPUにPushします。ここで、モデルは1回きり行えば大丈夫です。
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device", device)
Device cpu
model.to(device)
SimpleClassifier(
(list_layer): ModuleList(
(0): Linear(in_features=2, out_features=4, bias=True)
(1): Tanh()
(2): Linear(in_features=4, out_features=1, bias=True)
)
)
4-9. 学習
学習では以下のSTEPを行います。
- DataLoaderからバッチの取得
- バッチをモデルへ入力し予測値を得る
- 予測値と実測値からLossを計算する
- ロスに関して各パラメータの勾配を計算する
- 勾配方向へモデルのパラメータを更新する
- 3~7の操作をイタレーションの回数行う
訓練中はmodel.train()
によりモデルをtrainのmodeにします。これは、dropout
やBatchNorm
などが学習時と推論時で挙動が異なるためです。推論時はmodel.eval()
を実行します。
def train_model(model, optimizer, loss_module, data_loader, num_epochs=100):
# modelをtrain modeに設定する
model.train()
for epoch in tqdm(range(num_epochs)):
# 4. ここからbatch処理が走る
for data_inputs, data_labels in data_loader:
# dataをdeviceに移す
data_inputs = data_inputs.to(device)
data_labels = data_labels.to(device)
# 5. バッチをモデルへ入力し予測値を得る
preds = model(data_inputs)
# [Batch size, 1] -> [Batch size]へ変換
preds = preds.squeeze(dim=1)
# 6. Lossを計算する
loss = loss_module(preds, data_labels.float())
# 7. 勾配を計算する
# 必ず勾配計算前にzero_grad()を実行する
optimizer.zero_grad()
loss.backward()
## 8. 勾配情報をもとにパラメータを更新する
optimizer.step()
train_model(model, optimizer, loss_module, data_loader, 100)
100%|██████████| 100/100 [00:00<00:00, 158.24it/s]
モデルの保存と読み込み
model.state_dict()
に学習可能なパラメータを取得できます。これを用いて、torch.save()
により保存します。
また、torch.load()
によりパラメータを読み込むことができます。
# 学習可能なパラメータを取得
state_dict = model.state_dict()
print(state_dict)
# 保存
torch.save(state_dict, "our_model.tar")
OrderedDict([('list_layer.0.weight', tensor([[ 2.5581, 2.6572],
[ 2.1702, -1.7598],
[ 2.0181, -2.4040],
[ 1.8723, 1.7103]])), ('list_layer.0.bias', tensor([-1.0644, 0.8969, -1.0651, -2.7404])), ('list_layer.2.weight', tensor([[ 3.7602, -2.7670, 2.9931, -3.2199]])), ('list_layer.2.bias', tensor([-0.6014]))])
# 上記で保存したファイルを取得
state_dict = torch.load("our_model.tar")
# modelを作成し、パラメータを読み込む
new_model = SimpleClassifier(num_in=2, num_hid=4, num_out=1)
new_model.load_state_dict(state_dict)
<All keys matched successfully>
推論
推論時は、計算グラフ/勾配を作成/計算する必要もないため、メモリの削減と速度アップのために勾配計算をしません。計算グラフを構築しないためには、with torch.no_grad():
を使用します。
test_dataset = XORDataset(size=500)
# 推論時なので、shuffle=False, drop_last=Falseを指定しておきます
test_data_loader = data.DataLoader(test_dataset, batch_size=128, shuffle=False, drop_last=False)
def eval_model(model, data_loader):
# 推論モードに設定
model.eval()
true_preds, num_preds = 0., 0.
# これ以降は計算グラフを作成しない
with torch.no_grad():
for data_inputs, data_labels in data_loader:
# 推論
data_inputs, data_labels = data_inputs.to(device), data_labels.to(device)
preds = model(data_inputs)
preds = preds.squeeze(dim=1)
# sigmoidにより0~1の間に変換する
preds = torch.sigmoid(preds)
# 0か1へ変換する
pred_labels = (preds >= 0.5).long()
# Accuracyの計算
true_preds += (pred_labels == data_labels).sum()
num_preds += data_labels.shape[0]
acc = true_preds / num_preds
print(f"Accuracy of the model: {100.0*acc:4.2f}%")
# 今回は非常に簡単なデータなので、Acc.=100になっています
eval_model(model, test_data_loader)
Accuracy of the model: 100.00%
以上で例題を用いたpytorchの基礎は終わりです。
次回以降、各パートの詳細を見ていく予定です。
Discussion