👏

【やってみた】PyTorchでMNIST(モデル構築編)

2021/10/13に公開

はじめに

みんな大好きMNIST、きっとあなたもやってるはず!(モンハンのイャンクックレベルですね)

私もKeras全盛期時代はKerasで実装して遊んだことはあったのですが、PyTorchに移動してからMNISTで遊んでないなーと思い、今回はMNISTで遊んでみることにしました。

環境

WindowsにDockerを入れて動かしてます。すごいですね、Windowsでニューラルネットワークが動かせるのは最近まで知りませんでした・・・

気分的にUbuntuが必要だと思って、自分のゲーミングPC(Fallout76する用)にデュアルブートでインストールしようかと考えていたのですが、Windowsでできてとてもにっこりです

その他は

  • windows11
  • GeForce RTX2060 super
  • intel core i5 8500

みたいな感じです。

そもそもGPU使えてる?

Dockerで簡単にインストールした手前、GPUが正しく認識できているのかよくわかってません。
とりあえず確認してみます

import torch
print(torch.cuda.is_available())
True

とりあえずGPUは使えているみたいですね。

データセットの読み出し

今回はMnistをするので、PyTorchのライブラリから利用します。

import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

transform = T.Compose([T.ToTensor()])
traindata = torchvision.datasets.MNIST(root='./data', train=True,download=True,transform=transform)
trainloader = DataLoader(traindata,batch_size = 64)

まずは手書き数字を拝みたいと思います

data,label = iter(trainloader).next()
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10,4))

for i in range(10):
    ax = fig.add_subplot(2,5,i+1)
    ax.imshow(data[i,0])

懐かしの手書き数字が出てきましたね。(何度見ても左下の2が納得いかないのは私だけでしょうか・・・)

同様な流れでテスト画像もロードしていきます

test_data = torchvision.datasets.MNIST(root='./data', train=False,download=True,transform=transform)
test_loader = DataLoader(test_data,batch_size = 64)

これでデータ周りは準備オッケーですね

モデル構築

今回は有名どころではなく、我流でモデルを組んでみました。

import torch.nn as nn
class Custom_Model(nn.Module):
    def __init__(self):
        super(Custom_Model,self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(1,16,3,padding = 1),
            #nn.BatchNorm2d(16)
            nn.ReLU(),
            nn.Conv2d(16,16,3,padding = 1),
            #nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d((2,2)),
            nn.Conv2d(16,32,3,padding = 1),
            #nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32,32,3,padding = 1),
            #nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((2,2))
        )
        self.flatten = nn.Flatten()
        self.classifier = nn.Sequential(
            nn.Linear(7*7*32 , 10)
        )
    def forward(self,x):
        x = self.feature(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x

バッチ正規化がコメントアウトされているのは、とりあえずバッチ正規化がない場合どうなるか検証したかったからです。

学習部分の構築

ここにきて気づきました。
今あるデータセットはtrainとtestの二種類です。一般的に機械学習ではtrainとvalとtestが必要ですが、今回は存在しません。

trainから分割しようと思ったのですが、trainはすでにデータセット化してあり、どうしたものか・・・

困ったときのGoogle、調べてみるとなんでも出てきますね
こちらのページからtorch.utils.data.Subsetというものを発見。
この機能初めて知りました、こんな便利なものがあるなんて・・・

参考にしてこのようにしました

from torch.utils.data.dataset import Subset

trainvaldata = torchvision.datasets.MNIST(root='./data', train=True,download=True,transform=transform)
data_len = len(trainvaldata)

train_len = int(data_len*0.8)

train_dataset = Subset(trainvaldata , [i for i in range(0,train_len)])
val_dataset = Subset(trainvaldata , [i for i in range(train_len,data_len)])

train_loader = DataLoader(train_dataset,batch_size = 64)
val_loader = DataLoader(val_dataset,batch_size = 64)

続いて損失関数や最適化関数の定義です。
クラス分類なので、クロスエントロピーを利用します。
最適化関数は特にこだわりがないので、汎用性が高いAdamを採用。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

最後に学習部分の定義です。ここをいつも手書きしてるんですが、いい方法ないでしょうかね・・・

学習の経過を表示する関数

def show_score(epoch,max_epoch,itr,max_itr,loss,acc,is_val=False):
    print('\r{} EPOCH[{:03}/{:03}] ITR [{:04}/{:04}] LOSS:{:.05f} ACC:{:03f}'.format("VAL  " if is_val else "TRAIN",epoch,max_epoch,itr,max_itr,loss,acc*100),end = '')

損失の合計値を計算する関数

def cal_loss(n,batch,total_loss,loss):
    return (total_loss * n + loss*batch)/(n + batch), n + batch

正解数を計算する関数

def cal_acc(t,p):
    p_arg = torch.argmax(p,dim=1)
    return torch.sum(t == p_arg)

学習部分

EPOCH = 10

train_loss = []
train_acc = []
val_loss = []
val_acc = []

for e in range(EPOCH):
    train_total_loss = 0
    val_total_loss = 0
    train_total_acc = 0
    val_total_acc = 0
    counter = 0
    model.train()
    for n,(data,label) in enumerate(train_loader):
        optimizer.zero_grad()
        data = data.to(device)
        label = label.to(device)
        output = model(data)
        loss = criterion(output,label)
        loss.backward()
        optimizer.step()
        train_total_loss , counter = cal_loss(counter,data.shape[0],train_total_loss ,loss.detach().item())
        train_total_acc += cal_acc(label,output)
        show_score(e+1,EPOCH,n+1,len(train_loader),train_total_loss,train_total_acc/counter)
    train_loss.append(train_total_loss)
    train_acc.append(train_total_acc/counter)
    counter = 0
    model.eval()
    print()
    with torch.no_grad():
        for n,(data,label) in enumerate(val_loader):
            data = data.to(device)
            label = label.to(device)
            output = model(data)
            loss = criterion(output,label)
            val_total_loss , counter = cal_loss(counter,data.shape[0],val_total_loss ,loss.detach().item())
            val_total_acc += cal_acc(label,output)
            show_score(e+1,EPOCH,n+1,len(train_loader),val_total_loss,val_total_acc/counter,is_val = True)
    val_loss.append(val_total_loss)
    val_acc.append(val_total_acc/counter)
    print()

いざ学習!

今回は試しに10Epoch程度学習してみました。

TRAIN EPOCH[001/010] ITR [0750/0750] LOSS:0.24941 ACC:92.254166
VAL   EPOCH[001/010] ITR [0188/0188] LOSS:0.08449 ACC:97.375000
TRAIN EPOCH[002/010] ITR [0750/0750] LOSS:0.07097 ACC:97.810417
VAL   EPOCH[002/010] ITR [0188/0188] LOSS:0.06074 ACC:98.216667
TRAIN EPOCH[003/010] ITR [0750/0750] LOSS:0.04928 ACC:98.4562530
VAL   EPOCH[003/010] ITR [0188/0188] LOSS:0.05813 ACC:98.316666
TRAIN EPOCH[004/010] ITR [0750/0750] LOSS:0.03817 ACC:98.8125000
VAL   EPOCH[004/010] ITR [0188/0188] LOSS:0.05769 ACC:98.283333
TRAIN EPOCH[005/010] ITR [0750/0750] LOSS:0.02980 ACC:99.1020810
VAL   EPOCH[005/010] ITR [0188/0188] LOSS:0.05011 ACC:98.675003
TRAIN EPOCH[006/010] ITR [0750/0750] LOSS:0.02341 ACC:99.3125000
VAL   EPOCH[006/010] ITR [0188/0188] LOSS:0.05570 ACC:98.466667
TRAIN EPOCH[007/010] ITR [0750/0750] LOSS:0.02089 ACC:99.343750
VAL   EPOCH[007/010] ITR [0188/0188] LOSS:0.05452 ACC:98.5416640
TRAIN EPOCH[008/010] ITR [0750/0750] LOSS:0.01754 ACC:99.4437480
VAL   EPOCH[008/010] ITR [0188/0188] LOSS:0.04969 ACC:98.799995
TRAIN EPOCH[009/010] ITR [0750/0750] LOSS:0.01532 ACC:99.5104140
VAL   EPOCH[009/010] ITR [0188/0188] LOSS:0.05235 ACC:98.799995
TRAIN EPOCH[010/010] ITR [0750/0750] LOSS:0.01354 ACC:99.5499950
VAL   EPOCH[010/010] ITR [0188/0188] LOSS:0.04905 ACC:98.8999940

最後ちょっと精度が落ちてますね・・・
一番val_lossが低いところでセーブをしてないので、最後のモデルで精度評価をしてみます。

test_total_acc = 0
model.eval()
with torch.no_grad():
    for n,(data,label) in enumerate(test_loader):
        data = data.to(device)
        label = label.to(device)
        output = model(data)
        test_total_acc += cal_acc(label,output)
print(f"test acc:{test_total_acc/len(test_data)*100}")
test acc:98.8699951171875

微妙ですね・・・
次回はグラフ描画の実装と、精度向上手法を試していきたいと思います

Discussion