🔖
PyTorchでMNIST手書き画像の分類
概要
PyTorchとMNIST画像データセットを使って、手書き数字画像を分類するCNNの実装方法を解説する記事。
環境
- Google Colaboratory
動画
実装
ライブラリのインポート
#%matplotlib inline
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
デバイスの取得
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)
設定値
# Hyperparameters shared by the data loaders and the training loop.
batch_size = 128  # samples per mini-batch
n_epoch = 30  # number of passes over the training set
lr = 5e-4  # SGD learning rate
MNISTデータセットをロード
# Convert each image to a float tensor in [0, 1], then rescale the single
# grayscale channel to [-1, 1] via (x - 0.5) / 0.5.
_preprocess_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(_preprocess_steps)
# Download MNIST on first run (into ./datasets/) and wrap both splits in loaders.
data_root = './datasets/'
train_dataset = dset.MNIST(root=data_root, download=True, train=True, transform=transform)
# Shuffle the training data each epoch so mini-batches are not order-dependent.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = dset.MNIST(root=data_root, download=True, train=False, transform=transform)
# Evaluation order does not affect accuracy, so keep the test pass deterministic.
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
データセットを表示
def imshow(img):
    """Show a normalized image tensor with matplotlib.

    Undoes the (x - 0.5) / 0.5 normalization applied by ``transform`` and
    moves the channel axis last, the layout ``plt.imshow`` expects.

    Args:
        img: a 3-D tensor laid out as (C, H, W), e.g. the output of
            ``vutils.make_grid``. It must live on the CPU because
            ``.numpy()`` is called on it.
    """
    unnormalized = img / 2 + 0.5
    channels_last = unnormalized.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()
# Preview one shuffled training batch as an image grid, followed by its labels.
dataiter = iter(train_loader)
# DataLoader iterators no longer provide a ``.next()`` method in recent
# PyTorch releases; the builtin ``next()`` works on every version.
images, labels = next(dataiter)
imshow(vutils.make_grid(images))
# Iterate over the labels actually returned rather than assuming a full batch.
print(' '.join(f'{str(label.item()):5s}' for label in labels))
モデルの作成
class Net(nn.Module):
    """Small CNN that maps 1x28x28 grayscale digits to 10 class logits.

    Layers: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU ->
    2x2 max-pool -> channel dropout -> flatten -> fc(12*12*64 -> 128) ->
    ReLU -> dropout -> fc(128 -> 10). The two unpadded 3x3 convolutions
    shrink 28x28 to 24x24, and the pool halves that to 12x12, which is
    where the 12 * 12 * 64 input size of ``fc1`` comes from.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.pool = nn.MaxPool2d(2, 2)
        # Channel-wise dropout on the 4-D (N, 64, 12, 12) feature map.
        self.dropout1 = nn.Dropout2d()
        self.fc1 = nn.Linear(12 * 12 * 64, 128)
        # Element-wise dropout: its input is already flattened to (N, 128),
        # which Dropout2d does not support (it warns and is slated to error).
        self.dropout2 = nn.Dropout()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return raw class scores (logits) of shape (N, 10).

        ``x`` is expected to be a batch of 1x28x28 images; other spatial
        sizes would not produce 12 * 12 * 64 features after pooling.
        """
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.dropout1(x)
        x = x.view(-1, 12 * 12 * 64)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x


net = Net()
Optimizerと損失関数を作成
# Momentum SGD over all model parameters at the configured learning rate.
optimizer = optim.SGD(params=net.parameters(), lr=lr, momentum=0.9)
# Cross-entropy loss applied to the raw logits returned by Net.forward.
criterion = nn.CrossEntropyLoss()
トレーニングの開始
# Train the model. Log the mean loss every 100 mini-batches, record each
# epoch's mean loss, and plot that curve when training finishes.
net.to(device)
net.train()
plt_loss = []
plt_epoch = []
n_batches = len(train_loader)
for epoch in range(n_epoch):
    epoch_loss = 0.0
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Read the scalar once per step instead of twice.
        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        if i % 100 == 99:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0
    # ``i`` is the last 0-based batch index, so dividing by it overstated the
    # mean (and divided by zero for a one-batch loader); use the batch count.
    epoch_loss = epoch_loss / n_batches
    plt_loss.append(epoch_loss)
    # 1-based, matching the epoch numbers printed in the log above.
    plt_epoch.append(epoch + 1)
print('Finished Training')
plt.plot(plt_epoch, plt_loss, label="loss")
plt.xlabel("epoch")
plt.ylabel("mean training loss")
plt.legend()
plt.show()
モデルの保存
# Persist only the learned parameters (the state_dict), not the module object.
model_state = net.state_dict()
torch.save(model_state, 'model.pth')
評価
# Measure top-1 accuracy on the MNIST test split with dropout disabled.
correct = 0
total = 0
# Evaluate on the same device used for training (the GPU when available).
net.to(device)
net.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # Index of the highest logit per row is the predicted digit.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy: {100 * (correct / total):.2f} %')
参考
Discussion