🔖

PyTorchでMNIST手書き画像の分類

2022/07/22に公開

概要

PyTorchでMNIST画像データセットを使用した、手書き数字画像分類の実装方法の記事。

環境

  • Google Colaboratory

動画

https://youtu.be/VlddUi0XHjg

実装

ライブラリのインポート

#%matplotlib inline
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt

デバイスの取得

# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

設定値

batch_size = 128
n_epoch = 30
lr = 0.0005

MNISTデータセットをロード

# ToTensor() scales pixel values to [0, 1]; Normalize((0.5,), (0.5,)) then
# maps them to [-1, 1]. The display helper inverts this with img / 2 + 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Training data is reshuffled every epoch so mini-batches differ between epochs.
train_dataset = dset.MNIST(root='./datasets/', download=True, train=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# The test set is only used to measure accuracy, so shuffling adds nothing;
# a fixed order keeps evaluation runs reproducible.
test_dataset = dset.MNIST(root='./datasets/', download=True, train=False, transform=transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

データセットを表示

def imshow(img):
    """Show a (C, H, W) image tensor with matplotlib.

    Assumes the tensor was normalized with Normalize((0.5,), (0.5,)) as in this
    notebook, so it is first mapped back to [0, 1]. The tensor must live on the
    CPU because it is converted with .numpy(). Channels are moved last
    (H, W, C), which is the layout plt.imshow expects.
    """
    unnormalized = img / 2 + 0.5
    hwc = unnormalized.numpy().transpose((1, 2, 0))
    plt.imshow(hwc)
    plt.show()

# Preview one (shuffled) training mini-batch as an image grid, then print its labels.
dataiter = iter(train_loader)
# Use the builtin next(): DataLoader iterators in current PyTorch no longer
# provide a .next() method.
images, labels = next(dataiter)

imshow(vutils.make_grid(images))
# Iterate the labels actually returned, so a batch smaller than batch_size
# (e.g. a tiny dataset) does not raise IndexError.
print(' '.join(f'{str(label.item()):5s}' for label in labels))

モデルの作成

class Net(nn.Module):
    """Small CNN classifying 1x28x28 MNIST images into 10 digit classes.

    Layers:
        conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU -> 2x2 max-pool
        -> channel dropout -> flatten -> fc(12*12*64 -> 128) -> ReLU
        -> dropout -> fc(128 -> 10)

    forward() returns raw logits (no softmax), which is the input that
    nn.CrossEntropyLoss expects.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)   # 28x28 -> 26x26 (no padding)
        self.conv2 = nn.Conv2d(32, 64, 3)  # 26x26 -> 24x24
        self.pool = nn.MaxPool2d(2, 2)     # 24x24 -> 12x12
        # Dropout2d zeroes whole feature maps; it is applied while the tensor
        # is still 4-D (N, C, H, W), which is the input Dropout2d is meant for.
        self.dropout1 = nn.Dropout2d()
        self.fc1 = nn.Linear(12 * 12 * 64, 128)
        # After flattening, activations are 2-D (N, features). Dropout2d is
        # deprecated for 2-D input (PyTorch warns and recommends nn.Dropout),
        # so use plain element-wise dropout here. Same default p=0.5, no params.
        self.dropout2 = nn.Dropout()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return class logits of shape (N, 10) for x of shape (N, 1, 28, 28)."""
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.dropout1(x)
        # flatten(x, 1) keeps the batch dimension; a mismatched spatial size
        # now fails loudly in fc1 instead of being silently re-batched by view(-1, ...).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x


net = Net()

Optimizerと損失関数を作成

# The network outputs raw logits, so CrossEntropyLoss (log-softmax + NLL) is used.
criterion = nn.CrossEntropyLoss()
# Momentum SGD over every trainable parameter of the network.
optimizer = optim.SGD(params=net.parameters(), lr=lr, momentum=0.9)

トレーニングの開始

# Move the model to the selected device and enable training-mode behavior
# (dropout active).
net.to(device)
net.train()

plt_loss = []   # mean training loss of each epoch
plt_epoch = []  # 1-based epoch numbers, matching the printed log

for epoch in range(n_epoch):
    epoch_loss = 0.0
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Fetch the scalar once; .item() synchronizes with the device.
        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        # Log the mean loss of each window of 100 mini-batches.
        if i % 100 == 99:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

    # Divide by the number of mini-batches. The last 0-based index `i` is one
    # short of the batch count.
    plt_loss.append(epoch_loss / len(train_loader))
    plt_epoch.append(epoch + 1)

print('Finished Training')

plt.plot(plt_epoch, plt_loss, label="loss")
plt.xlabel("epoch")
plt.ylabel("mean training loss")
plt.legend()  # without this the "loss" label is never shown
plt.show()

モデルの保存

# Persist only the learned parameters (the state_dict), not the module object itself.
weights = net.state_dict()
torch.save(weights, 'model.pth')

評価

# Measure top-1 accuracy on the MNIST test set.
correct = 0
total = 0
# Evaluate on the same device used for training instead of copying the model
# back to the CPU.
net.to(device)
net.eval()  # disable dropout for inference
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # The index of the largest logit is the predicted digit.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy: {:.2f} %'.format(100 * correct / total))

参考

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

Discussion