🕌

Intro2DL:Pytorchの基礎 XORの例題をもとに

2022/09/07に公開

例題:XOR

PytorchにおけるNNの学習/推論を行うにあたり、XORの例題を使用します。
XORとは、x_1x_2の片方が1で残りが0であればラベル1、それら以外のケースであればラベル0を入力するような問題で、単純な線形関数ではうまく推論できないです。

この例題を解くために、Pytorchのモジュールを使用していきます。

import os
import numpy as np 
import time

import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

モデル構築の流れ

基本的には以下の流れで実行します。

  1. モデルの定義
  2. DataLoaderの準備
  3. lossやoptimizerなどの準備
    ~ここから学習開始~
  4. DataLoaderからバッチの取得
  5. バッチをモデルへ入力し予測値を得る
  6. 予測値と実測値からLossを計算する
  7. ロスに関して各パラメータの勾配を計算する
  8. 勾配方向へモデルのパラメータを更新する
  9. 3~7の操作をイタレーションの回数行う

1. モデルの定義

XORの例題から、入力はx_1x_2の2変数で、ラベルは01のみです。
また、モデルは活性化関数がtanh, 隠れ層を1つだけ持つネットワークを考えます。

import torch
import torch.nn as nn

Pytorchのモデルは以下のような構成である必要があります。

class TemplateNet(nn.Module):
    
    def __init__(self):
        super().__init__()
        
    def forward(self, x):
        # xを入力とした時の計算
        pass
class SimpleClassifier(nn.Module):

    def __init__(self, num_in, num_hid, num_out):
        super().__init__()
        self.layer1 = nn.Linear(num_in, num_hid)
        self.act_fn = nn.Tanh()
        self.layer2 = nn.Linear(num_hid, num_out)
    
    def forward(self, x):
        x = self.layer1(x)
        x = self.act_fn(x)
        x = self.layer2(x)

        return x
model = SimpleClassifier(num_in=2, num_hid=4, num_out=1)
print(model)
SimpleClassifier(
  (layer1): Linear(in_features=2, out_features=4, bias=True)
  (act_fn): Tanh()
  (layer2): Linear(in_features=4, out_features=1, bias=True)
)

この時、モデルが持つパラメータを見てみます。Tanhはパラメータを持たないので、layerのパラメータのみ表示されます。

for name, param in model.named_parameters():
    print(f"Parameter {name}, shape {param.shape}")
Parameter layer1.weight, shape torch.Size([4, 2])
Parameter layer1.bias, shape torch.Size([4])
Parameter layer2.weight, shape torch.Size([1, 4])
Parameter layer2.bias, shape torch.Size([1])

注意点としては、layerの定義の仕方です。self.?のような形で定義しないと以下のように認識/登録されないです。

class SimpleClassifier(nn.Module):

    def __init__(self, num_in, num_hid, num_out):
        super().__init__()
        layer1 = nn.Linear(num_in, num_hid)
        act_fn = nn.Tanh()
        layer2 = nn.Linear(num_hid, num_out)
        self.list_layer = [layer1, act_fn, layer2]
    
    def forward(self, x):
        for layer in self.list_layer:
            x = layer(x)

        return x

model = SimpleClassifier(num_in=2, num_hid=4, num_out=1)
print(model)

for name, param in model.named_parameters():
    print(f"Parameter {name}, shape {param.shape}")
SimpleClassifier()

上記のように定義したい場合は、nn.ModuleListnn.ModuleDict, nn.Sequentialを使用します。ここでは、nn.ModuleListの例のみを見ます。

class SimpleClassifier(nn.Module):

    def __init__(self, num_in, num_hid, num_out):
        super().__init__()
        layer1 = nn.Linear(num_in, num_hid)
        act_fn = nn.Tanh()
        layer2 = nn.Linear(num_hid, num_out)
        self.list_layer = nn.ModuleList([layer1, act_fn, layer2])
    
    def forward(self, x):
        for layer in self.list_layer:
            x = layer(x)

        return x

model = SimpleClassifier(num_in=2, num_hid=4, num_out=1)
print(model)

for name, param in model.named_parameters():
    print(f"Parameter {name}, shape {param.shape}")
SimpleClassifier(
  (list_layer): ModuleList(
    (0): Linear(in_features=2, out_features=4, bias=True)
    (1): Tanh()
    (2): Linear(in_features=4, out_features=1, bias=True)
  )
)
Parameter list_layer.0.weight, shape torch.Size([4, 2])
Parameter list_layer.0.bias, shape torch.Size([4])
Parameter list_layer.2.weight, shape torch.Size([1, 4])
Parameter list_layer.2.bias, shape torch.Size([1])

2. DataLoaderの準備

pytorchでデータを扱う際は、DatasetDataLoaderを使用します。
シンプルには、Datasetはi番目のデータを取得するためのクラスで、DataLoaderはバッチ処理などを効率的に実装できるクラスです。

# dataを効率的に扱うモジュールをインポート
import torch.utils.data as data

Datasetクラスはi番目のデータを返す__getitem__()とデータのサイズを返す__len__()を持ちます。

class XORDataset(data.Dataset):

    def __init__(self, size, std=0.1):
        """
        Inputs:
            size - Number of data points we want to generate
            std - Standard deviation of the noise (see generate_continuous_xor function)
        """
        super().__init__()
        self.size = size
        self.std = std
        self.generate_continuous_xor()

    def generate_continuous_xor(self):
        data = torch.randint(low=0, high=2, size=(self.size, 2), dtype=torch.float32)
        label = (data.sum(dim=1) == 1).to(torch.long)
        data += self.std * torch.randn(data.shape)

        self.data = data
        self.label = label

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        data_point = self.data[idx]
        data_label = self.label[idx]
        return data_point, data_label
dataset = XORDataset(size=2500)
print(dataset.size)
print(dataset[0])
2500
(tensor([ 0.9450, -0.0377]), tensor(1))

DataLoaderクラスは上記で定義したDataset__getitem__を使用して、バッチ処理などをよしなに実行してくれます。

オプションは以下の通りです。

  • batch_size: バッチサイズを指定します
  • shuffle: データセットの並び順をシャッフルするかどうか
  • pin_memory: GPU上のメモリにデータをコピーします。サイズが大きい時には有効ですが、GPUのメモリを消費するので単なる検証や推論の時には必要ないです。
  • drop_last: batch_sizeでデータの数を割り切れない時の余りを使用するかどうか(訓練時のみバッチサイズを一定に保つために必要)
data_loader = data.DataLoader(dataset, batch_size=128, shuffle=True, drop_last=True)
print(data_loader)

data_inputs, data_labels = next(iter(data_loader))
print("Data inputs", data_inputs.shape, "\n", data_inputs)
print("Data labels", data_labels.shape, "\n", data_labels)
<torch.utils.data.dataloader.DataLoader object at 0x108bbde20>
Data inputs torch.Size([128, 2]) 
 tensor([[-7.2804e-02,  1.1041e+00],
        [ 1.1450e+00,  1.7077e-02],
        [ 4.0985e-02, -5.4373e-02],
        [ 7.5966e-02,  9.7802e-01],
        [ 2.2727e-02,  8.9922e-01],
        [ 1.2303e-01,  1.0085e+00],
        [-7.5358e-02,  7.5478e-02],
        [ 1.0237e+00,  1.1585e+00],
        [ 1.0081e+00, -6.0899e-02],
        [ 1.0971e+00,  8.7675e-01],
        [ 1.2261e+00,  1.0328e+00],
        [-3.4534e-02,  8.2188e-01],
        [-1.6522e-02,  9.7790e-01],
        [ 9.6733e-01,  1.0362e+00],
        [ 1.0112e-01,  8.8916e-01],
        [-5.9085e-02,  8.7525e-01],
        [ 2.0926e-02,  1.0810e-01],
        [ 1.6433e-02, -9.2412e-03],
        [ 1.0063e+00,  9.2770e-01],
        [ 2.3867e-02,  9.4516e-01],
        [ 9.3029e-01,  1.0644e+00],
        [ 1.0317e+00, -5.0927e-04],
        [-7.4040e-02,  1.1017e+00],
        [-5.3543e-04,  2.0390e-02],
        [ 7.6709e-02,  9.2210e-01],
        [ 1.9267e-02,  1.0064e-01],
        [-1.4015e-01,  7.9226e-01],
        [ 2.3145e-01,  9.9505e-01],
        [ 9.6408e-01,  4.8191e-02],
        [ 6.9211e-02, -7.1338e-02],
        [ 3.2373e-02,  1.0113e+00],
        [ 1.2508e+00,  8.0529e-02],
        [ 1.0965e+00, -1.4765e-01],
        [-4.0061e-02, -5.2052e-02],
        [-2.0837e-01,  8.3516e-01],
        [ 1.0791e+00, -7.4755e-02],
        [ 7.9891e-01,  2.0973e-01],
        [ 9.0706e-01,  7.6696e-01],
        [ 9.1724e-01,  9.1130e-01],
        [ 1.1240e+00, -2.8604e-02],
        [-1.5656e-01, -6.1603e-02],
        [ 3.9071e-02,  1.0703e+00],
        [ 1.1136e-01,  7.6418e-02],
        [-4.0319e-02,  9.1122e-01],
        [-1.2613e-01,  9.2346e-01],
        [ 8.6311e-01,  4.1220e-03],
        [-1.7388e-01,  1.1669e-01],
        [-7.4471e-02, -1.9813e-01],
        [ 1.0433e+00,  1.0701e+00],
        [ 1.2681e-01,  1.0351e+00],
        [-7.1518e-02, -2.9065e-01],
        [-3.2377e-02,  1.0505e+00],
        [-2.6789e-01,  1.2169e-01],
        [ 1.8959e-01, -1.2496e-01],
        [ 2.9532e-02,  1.2342e-01],
        [ 8.3225e-01, -9.2726e-02],
        [ 9.6130e-01,  1.0646e+00],
        [ 9.1031e-01,  9.7370e-01],
        [ 1.0166e+00,  1.0392e+00],
        [ 2.0528e-01, -1.2846e-02],
        [ 1.0167e+00,  9.3552e-01],
        [ 8.8396e-01,  9.7615e-01],
        [ 1.0432e+00,  1.0095e+00],
        [ 9.9360e-01,  1.0981e+00],
        [ 1.3962e-03,  1.0370e+00],
        [ 3.1574e-02,  7.6176e-03],
        [ 1.1352e+00,  1.0465e+00],
        [ 8.2953e-03, -9.6899e-02],
        [ 1.0000e+00,  3.2121e-02],
        [ 9.7082e-02,  8.8162e-01],
        [ 5.3385e-02,  1.0033e-01],
        [ 9.5098e-01,  1.3654e-02],
        [ 6.9419e-02,  1.0831e+00],
        [ 1.0508e+00,  1.0540e+00],
        [ 2.6655e-02,  8.9375e-01],
        [-1.0887e-01,  2.2077e-02],
        [ 8.7785e-01,  9.4118e-01],
        [-4.8292e-03, -1.5647e-01],
        [ 7.2836e-01,  6.2586e-02],
        [-6.3113e-03,  2.2185e-02],
        [-1.3274e-01,  8.6549e-01],
        [ 1.0961e+00,  1.0224e+00],
        [ 1.0638e+00,  1.5344e-02],
        [-9.7978e-02,  1.6069e-01],
        [ 1.0346e+00,  1.1706e+00],
        [ 1.2919e-01, -8.1029e-02],
        [ 1.0119e+00,  9.2153e-01],
        [-2.8104e-02, -2.3104e-02],
        [ 1.1002e-01,  1.0594e+00],
        [ 2.3332e-01, -8.6315e-02],
        [ 1.4667e-01,  9.6280e-01],
        [ 1.1106e+00,  8.0643e-01],
        [ 1.2023e-01,  1.0180e+00],
        [-6.0721e-02,  1.1322e+00],
        [ 8.6859e-01,  9.8454e-01],
        [ 1.0771e-01, -8.9369e-02],
        [ 4.4324e-02,  1.0185e+00],
        [-8.9303e-02,  9.8590e-01],
        [ 9.6519e-01,  1.0608e+00],
        [ 9.3951e-01, -5.4688e-02],
        [ 2.2585e-02,  9.2614e-01],
        [-7.6468e-02,  2.1452e-01],
        [ 7.2733e-02, -7.3930e-02],
        [ 1.1124e+00,  1.1067e+00],
        [ 2.8434e-01, -9.9704e-02],
        [ 1.0294e+00, -1.8035e-01],
        [-2.8473e-02,  9.1512e-01],
        [ 1.9851e-01,  1.5367e-03],
        [ 1.1980e-01,  8.8696e-01],
        [ 2.0547e-01,  1.0220e+00],
        [-1.9407e-01,  8.4612e-01],
        [-1.5697e-01, -1.0190e-01],
        [-8.0814e-02,  1.1973e-01],
        [ 1.0253e+00,  9.5834e-01],
        [ 1.0403e+00,  1.0607e+00],
        [ 9.2056e-01,  6.0310e-02],
        [ 9.5124e-01,  1.2482e+00],
        [-2.2171e-02,  9.7395e-02],
        [-2.3042e-01,  3.6538e-02],
        [ 8.7662e-02,  9.6314e-01],
        [ 6.3132e-03,  7.1874e-02],
        [ 1.6658e-01,  1.4259e-01],
        [ 9.4618e-01,  8.5934e-01],
        [ 1.1190e+00,  9.6802e-01],
        [ 7.7979e-02,  9.1017e-01],
        [ 1.0036e+00,  1.1159e+00],
        [ 1.0358e+00,  9.7519e-01],
        [ 2.8420e-02, -3.8281e-02]])
Data labels torch.Size([128]) 
 tensor([1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0,
        1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0,
        0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1,
        1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0,
        1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1,
        0, 0, 0, 0, 1, 0, 0, 0])

3. lossやoptimizerなどの準備

Lossの定義

今回は2値分類なので、Binary Cross Entropy (BCE)ロスを使用します。pytorchではBCE lossはnn.BCELoss()nn.BCEWithLogitsLoss()の2種類あります。nn.BCEWithLogitsLoss()はSigmoidの層とBCE lossが1つのクラスになったもので、1つに結合することでlog-sum-exp trickにより数値的に安定します。そのため、今回はモデルの定義のところではSigmoidの層を抜きにして、ロスに``nn.BCEWithLogitsLoss()`を使用します。

https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html#torch.nn.BCEWithLogitsLoss

loss_module = nn.BCEWithLogitsLoss()

otimizzerの定義

様々な最適化手法がありますが、今回は確率的勾配降下法(Stochastic Gradient Descent; SGD)を使用します。勾配の更新レベルを決める学習率を指定する必要がありますが、今回の小さなネットワークには0.1を設定します。

optimizerは.step().zero_grad()という関数を持ちます。.step()関数は計算された勾配情報を元にパラメータを更新します。.zero_grad()関数はすべてのパラメータの勾配情報を0にします。これがなぜ必要なのかというと、勾配情報は計算されるたびに前回計算された値に加算されていくためです。そのため、.backward()の前に必ず.zero_grad()を行う必要があります。

# ここで、modelのパラメータを入力します
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

モデルをGPUへ移す

データをGPUにPushします。ここで、モデルは1回きり行えば大丈夫です。

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device", device)
Device cpu
model.to(device)
SimpleClassifier(
  (list_layer): ModuleList(
    (0): Linear(in_features=2, out_features=4, bias=True)
    (1): Tanh()
    (2): Linear(in_features=4, out_features=1, bias=True)
  )
)

4-9. 学習

学習では以下のSTEPを行います。

  1. DataLoaderからバッチの取得
  2. バッチをモデルへ入力し予測値を得る
  3. 予測値と実測値からLossを計算する
  4. ロスに関して各パラメータの勾配を計算する
  5. 勾配方向へモデルのパラメータを更新する
  6. 3~7の操作をイタレーションの回数行う

訓練中はmodel.train()によりモデルをtrainのmodeにします。これは、dropoutBatchNormなどが学習時と推論時で挙動が異なるためです。推論時はmodel.eval()を実行します。

def train_model(model, optimizer, loss_module, data_loader, num_epochs=100):
    # modelをtrain modeに設定する
    model.train()

    for epoch in tqdm(range(num_epochs)):
        # 4. ここからbatch処理が走る
        for data_inputs, data_labels in data_loader:
            
            # dataをdeviceに移す
            data_inputs = data_inputs.to(device)
            data_labels = data_labels.to(device)

            # 5. バッチをモデルへ入力し予測値を得る 
            preds = model(data_inputs)
            # [Batch size, 1] -> [Batch size]へ変換
            preds = preds.squeeze(dim=1) 

            # 6. Lossを計算する
            loss = loss_module(preds, data_labels.float())

            # 7. 勾配を計算する
            # 必ず勾配計算前にzero_grad()を実行する
            optimizer.zero_grad() 
            loss.backward()

            ## 8. 勾配情報をもとにパラメータを更新する
            optimizer.step()

train_model(model, optimizer, loss_module, data_loader, 100)
100%|██████████| 100/100 [00:00<00:00, 158.24it/s]

モデルの保存と読み込み

model.state_dict()に学習可能なパラメータを取得できます。これを用いて、torch.save()により保存します。
また、torch.load()によりパラメータを読み込むことができます。

# 学習可能なパラメータを取得
state_dict = model.state_dict()
print(state_dict)

# 保存
torch.save(state_dict, "our_model.tar")
OrderedDict([('list_layer.0.weight', tensor([[ 2.5581,  2.6572],
        [ 2.1702, -1.7598],
        [ 2.0181, -2.4040],
        [ 1.8723,  1.7103]])), ('list_layer.0.bias', tensor([-1.0644,  0.8969, -1.0651, -2.7404])), ('list_layer.2.weight', tensor([[ 3.7602, -2.7670,  2.9931, -3.2199]])), ('list_layer.2.bias', tensor([-0.6014]))])
# 上記で保存したファイルを取得
state_dict = torch.load("our_model.tar")

# modelを作成し、パラメータを読み込む
new_model = SimpleClassifier(num_in=2, num_hid=4, num_out=1)
new_model.load_state_dict(state_dict)
<All keys matched successfully>

推論

推論時は、計算グラフ/勾配を作成/計算する必要もないため、メモリの削減と速度アップのために勾配計算をしません。計算グラフを構築しないためには、with torch.no_grad():を使用します。

test_dataset = XORDataset(size=500)
# 推論時なので、shuffle=False, drop_last=Falseを指定しておきます
test_data_loader = data.DataLoader(test_dataset, batch_size=128, shuffle=False, drop_last=False) 
def eval_model(model, data_loader):
    # 推論モードに設定
    model.eval()

    true_preds, num_preds = 0., 0.
    # これ以降は計算グラフを作成しない
    with torch.no_grad():
        for data_inputs, data_labels in data_loader:
            # 推論
            data_inputs, data_labels = data_inputs.to(device), data_labels.to(device)
            preds = model(data_inputs)
            preds = preds.squeeze(dim=1)
            # sigmoidにより0~1の間に変換する
            preds = torch.sigmoid(preds)
            # 0か1へ変換する
            pred_labels = (preds >= 0.5).long()
            
            # Accuracyの計算
            true_preds += (pred_labels == data_labels).sum()
            num_preds += data_labels.shape[0]
            
    acc = true_preds / num_preds
    print(f"Accuracy of the model: {100.0*acc:4.2f}%")
# 今回は非常に簡単なデータなので、Acc.=100になっています
eval_model(model, test_data_loader)
Accuracy of the model: 100.00%

以上で例題を用いたpytorchの基礎は終わりです。
次回以降、各パートの詳細を見ていく予定です。

Discussion