【PyTorch】手書き文字データセットMNISTで画像分類してみた
はじめに
はじめに,本記事で筆者が使用している環境や必要なライブラリについてまとめます.
環境
PC | MacBook Pro (16-inch, 2019) |
---|---|
OS | Monterey |
CPU | 2.3 GHz 8コアIntel Core i9 |
メモリ | 16GB |
Python | 3.9 |
使用するライブラリ
本記事で用いるライブラリとバージョンをまとめますが,特に気にせず
pip install numpy opencv-python torch torchvision
で問題ないかと思います.
ライブラリ | バージョン |
---|---|
numpy | 1.21.2 |
opencv-python | 4.5.5.64 |
torch | 1.10.1 |
torchvision | 0.11.2 |
念の為ライブラリとバージョンも記しておきます.
MNIST画像のロード
早速MNIST
の画像を表示してみましょう.
DataLoaderの作成
まず,DataLoader
を作成しましょう.
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
# 画像にどの様な変形を加えるか
transform = transforms.Compose([
#Tensor型に
transforms.ToTensor()
])
dataset = MNIST(
# データセットのパス
root=root,
train=True,
# ダウンロードしている場合はdownload=False
download=True,
transform=transform
)
dataloader = DataLoader(
dataset=dataset,
batch_size=64,
# シャッフルしない場合はshuffle=Flase
shuffle=True,
# batch_sizeを固定
drop_last=True
)
transforms
について詳しく知りたい,自作したいという方はこちらの記事を参考にしてください↓
OpenCVで描画
最後にOpenCV
で描画します.
img, label = iter(dataloader).next()
img = np.array(img)[0][0]
cv2.imshow(f"{label.item()}", img)
cv2.waitKey(0)
上記コードを実行すると下記画像が出力されると思います.
コードまとめ
以下に本セクションで扱ったコードをまとめます.
import cv2
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from typing import List
# dataloaderを取得する関数
def get_dataloader(root: str, batch_size: int=64) -> DataLoader:
# 画像にどの様な変形を加えるか
transform = transforms.Compose([
# Tensor型に
transforms.ToTensor(),
# ランダムで回転させる
+ transforms.RandomRotation(degree=90)
])
dataset = MNIST(
# データセットのパス
root=root,
train=True,
# ダウンロードしている場合はdownload=False
download=True,
transform=transform
)
dataloader = DataLoader(
dataset=dataset,
batch_size=batch_size,
# シャッフルしない場合はshuffle=Flase
shuffle=True,
# batch_sizeを固定
drop_last=True
)
return iter(dataloader).next()
if __name__ == "__main__":
img, label = get_dataloader(root="data")
img = np.array(img)[0][0]
cv2.imshow(f"{label.item()}", img)
cv2.waitKey(0)
transforms
に新たにtransforms.RandomRotation
を加えてみました.結果も念の為貼っておきます.
モデルの作成
今回利用するモデルは,ディープラーニング界では有名なResNet
です.
from torchvision.models import resnet18
model = resnet18(pretrained=True)
上記の様にprint
してみるとmodel
の構造を知ることができます.
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
まず,output
の2行目 (ResNet.conv1
) を見てみましょう.chanel
数が3 (カラー画像用) となっていますが,MNIST
画像のchanel
数は1 (モノクロ画像) なので変更しなければなりません.
また,最終行 (ResNet.fc
) を見てみましょう.out_features=1000
となっています.
これは1000クラスの分類ということですが,実際の手書き文字は10クラスなので変更する必要があります.
なので,下記の様に最初の畳み込み層と最終の全結合層を変更しましょう.
from torch import nn
from torchvision.models import resnet18
def get_resnet(pretrained: bool=True, num_classes: int=10) -> nn.Module:
# ImageNetで事前学習済みの重みをロード
model = resnet18(pretrained=pretrained)
# ここで更新する部分の重みは初期化される
+ model.conv1 = nn.Conv2d(
+ in_channels=1,
+ out_channels=64,
+ kernel_size=model.conv1.kernel_size,
+ stride=model.conv1.stride,
+ padding=model.conv1.padding,
+ bias=False
+ )
+ model.fc = nn.Linear(
+ in_features=model.fc.in_features,
+ out_features=num_classes
+ )
return model
上記のコードでは+
のついている行を追記しています.
実施にこの関数を用いてモデルをprint
してみると変更されていることがわかります.
モデルを学習させてみよう
いよいよResNet
に学習をさせていきますが,その前にディレクトリ構造とコードを整理します.
ファイル構造
以下のディレクトリ構造を参考にして,自分のディレクトリを見直してください.同じ構造でなければ作り直すことをおすすめします.
.
├─ data
├─ dataset
│ └── mnist.py
├─ model
│ └── resnet.py
└─ main.py
事前に準備したコード
下記のmnist.py
とresnet.py
はこれまでのチャプターで記述してきたコードなので,コピーすれば問題ないかと思います.
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
def get_dataloader(root: str, batch_size: int=64):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.RandomRotation(degrees=90),
])
dataset = MNIST(
root=root,
train=True,
download=True,
transform=transform
)
dataloader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True
)
return dataloader
from torch import nn
from torchvision.models import resnet18
def get_resnet(pretrained: bool=True, num_classes: int=10) -> nn.Module:
model = resnet18(pretrained=pretrained)
model.conv1 = nn.Conv2d(
in_channels=1,
out_channels=64,
kernel_size=model.conv1.kernel_size,
stride=model.conv1.stride,
padding=model.conv1.padding,
bias=False
)
model.fc = nn.Linear(
in_features=model.fc.in_features,
out_features=num_classes
)
return model
モデルの学習
お待たせしました.やっと本題です.
これからはmain.py
を編集していきます.
import torch
from torch import nn, optim
from tqdm import tqdm
from dataset.mnist import get_dataloader
from model.resnet import get_resnet
def train(total_epoch: int=20):
dataloader = get_dataloader(root="data", batch_size=64)
model = get_resnet(pretrained=True)
# オプティマイザーの定義
+ optimizer = optim.SGD(
+ params=model.parameters(),
+ lr=1e-3
+ )
# スケジューラーの定義
+ scheduler = optim.lr_scheduler.OneCycleLR(
+ optimizer=optimizer,
+ max_lr=1e-3,
+ total_steps=len(dataloader),
+ )
# 損失関数の定義
+ criterion = nn.CrossEntropyLoss()
+ model.train()
for epoch in range(total_epoch):
accuracy, train_loss = 0.0, 0.0
# tqdmを用いるとプログレスバーの表示ができる
for images, labels in tqdm(dataloader):
+ optimizer.zero_grad()
# モデルからの出力
+ out = model(images)
# lossの算出
+ loss = criterion(out, labels)
+ loss.backward()
+ optimizer.step()
# 推測値
preds = out.argmax(axis=1)
train_loss += loss.item()
# 正答率の算出
accuracy += torch.sum(preds == labels).item() / len(labels)
+ scheduler.step()
# 値の出力
print(f"epoch: {epoch + 1}")
print(f"loss: {train_loss / len(dataloader)}")
print(f"accuracy: {accuracy / len(dataloader)}")
if __name__ == "__main__":
train()
上記のコードの内+
の部分がモデルを学習させるために必要なコードです.また,+
でない部分を用いて,モデルを評価するために必要な出力を取得することができます.
出力結果
以下に出力結果を載せます.新たにtqdm
を追記することでプログレスバーを表示させることができます.
100%|██████████████████████████████████████| 937/937 [06:43<00:00, 2.32it/s]
loss: 1.5931195519395418
accuracy: 0.43836712913553894
100%|██████████████████████████████████████| 937/937 [06:56<00:00, 2.25it/s]
loss: 0.9492611091190367
accuracy: 0.6811299359658485
100%|██████████████████████████████████████| 937/937 [06:49<00:00, 2.29it/s]
loss: 0.6566917853650568
accuracy: 0.7852354589114194
100%|██████████████████████████████████████| 937/937 [07:02<00:00, 2.22it/s]
loss: 0.49791804670078904
accuracy: 0.8412820170757738
100%|██████████████████████████████████████| 937/937 [06:53<00:00, 2.27it/s]
loss: 0.4067016332832701
accuracy: 0.8705809765208111
利点がいくつかありますが,本記事では2点ほど紹介します.
- 後どれくらいの時間で出力を得ることができるかわかる
- 途中の出力も表示させることができる (アレンジする必要あり)
おわりに
お疲れ様でした!そして,長い記事を最後まで読んで頂きありがとうございました!
最後にこちらにより詳しい内容を記しているので是非ご覧ください↓
Discussion