Chapter 06

もっと MNIST で高い精度を出す

ろぐみ
ろぐみ
2023.02.05に更新

高みを目指せ

前の章で MNIST を解いてみました。
私の環境ではテストデータに対して 97% ほどの精度が出ました。

・・・もっと出したい、もっと力が欲しい、そう思った方はいませんか?

今回は、MNIST を解くモデルをもっと高い精度で解くために、いくつか工夫をしてみます。

また、少しリッチなモデルになるので今回から GPU を使うことにします。
Colab の編集 -> ノートブックの設定 -> ハードウェアアクセラレータ -> GPU を選択してください。

また、以下を実行して GPU を使う準備をしておきましょう。

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

データを読み込む部分は前回と同じです。

コードを見る
import torch
from torchvision import datasets, transforms

# バッチサイズ
BATCH_SIZE = 64

# 画像データの変換方法を指定
transform = transforms.Compose([
    transforms.ToTensor(),        # テンソルに変換 & 0-255 の値を 0-1 に変換
])

# MNIST を取得
train_dataset = datasets.MNIST(
    root='./data',        # データを保存するディレクトリ
    train=True,           # 学習用データを取得
    download=True,        # データがない場合はダウンロードする
    transform=transform,  # 画像データの変換方法を指定
)

# テスト用のデータを取得
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# データローダーの作成
train_loader = torch.utils.data.DataLoader(
    train_dataset,          # データセット
    batch_size=BATCH_SIZE,  # バッチサイズを指定
    shuffle=True,           # シャッフルする
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

print(train_dataset[0][0].shape)

モデルを改造する

前回のモデルは単純に2次元のデータを1次元に変換して、線形層を通して出力していました。
2次元のものを1次元に圧縮してしまうと、本来のデータの構造が失われてしまいます。

本来、人間が画像を見るときは2次元的に見るはずです。

つまりあるピクセルがあったとき、その周りのピクセルとの関係性を見ているのですが、1次元に圧縮してしまうと、その関係性が失われてしまいます。

バラバラ事件でお亡くなりになったA

そこで、今回は2次元のままで、線形層を通すのではなく、畳み込み層を通してみます。
いわゆる CNN(Convolutional Neural Network)を作ってみましょう。

畳み込み層

畳み込み層は、画像の特徴を抽出するための層です。

例えば、画像の中に猫がいるとき、猫の特徴として、目や鼻、耳などがあります。
このような特徴を抽出するために、畳み込み層を使います。

畳み込み層ではフィルターのようなものを用意して、それをずらしながら画像にかけていきます。

プーリング層

畳み込み層を通すと、画像の特徴が抽出されます。
しかし、そのままでは、画像の大きさが大きくなってしまいます。
そこで、プーリング層を使って、画像の大きさを小さくします。
大きな画像の中から代表的なものを取り出すというイメージです。

モデルの構造

今回は、畳み込み層とプーリング層を使って、モデルを作ってみます。

import torch
import torch.nn as nn

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

class MNISTCNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            # 畳み込み層
            # 1チャンネルを32チャンネルにする、3x3のフィルターを使う、1つずつずらす
            nn.Conv2d(1, 32, 3, 1),
            # 活性化関数
            nn.ReLU(),
            # プーリング層、2x2の領域から最大のものを1つ取り出す
            nn.MaxPool2d(2, 2),
            # Dropout
            nn.Dropout(0.1),
        )
        self.layer2 = nn.Sequential(
            # 畳み込み層
            # 32チャンネルを64チャンネルにする、3x3のフィルターを使う、1つずつずらす
            nn.Conv2d(32, 64, 3, 1),
            # 活性化関数
            nn.ReLU(),
            # プーリング層、2x2の領域から最大のものを1つ取り出す
            nn.MaxPool2d(2, 2),
            # Dropout
            nn.Dropout(0.1),
        )
        self.layer3 = nn.Sequential(
            # (チャンネル数 x 縦 x 横)を1次元に変換する
            nn.Flatten(),
            # 線形層
            nn.Linear(64*5*5, 256),
            # 活性化関数
            nn.ReLU(),
            # 線形層
            nn.Linear(256, 10),
            # 出力層
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

結構難しくなりましたね。

nn.Conv2d は、畳み込み層を作るための関数です。
第1引数は、入力のチャンネル数、第2引数は、出力のチャンネル数、第3引数は、フィルターのサイズ、第4引数は、フィルターをずらす量になります。
今回は、1チャンネルの画像を32チャンネルに変換するため、第1引数は1、第2引数は32になります。
フィルターのサイズは3x3なので、第3引数は3になります。
フィルターをずらす量は1なので、第4引数は1になります。

例えば最初に入力された画像を Conv2d に通したときは、3x3 のフィルターを1つずつずらしながら進むので、出力のサイズは、(32, 28-3+1, 28-3+1) = (32, 26, 26) になります。

nn.MaxPool2d は、プーリング層を作るための関数です。
第1引数は、領域のサイズ、第2引数は、領域をずらす量になります。
今回はどちらも2なので、2x2の領域から最大のものを取り出していく作業をひたすら繰り返していきます。

例えば最初に MaxPool2d を通るとき、出力のサイズは (32, 26/2, 26/2) = (32, 13, 13) になります。

nn.Dropout は、Dropout 機構を作るための関数です。
Dropout 機構とは学習時のみ、一定の確率でニューロンを無効にする機構です。
これにより、過学習を防ぐ効果があると言われています。

nn.Flatten は、入力を1次元に変換するための関数です。

モデルの図も載せておきます。

各工程ごとにどのようにテンソルが変化するかも併記したので、実際に 3x3 のフィルターをずらしながらどのように畳み込みが行われるかがわかりやすいと思います。

とても複雑で頭良さそうな感じになってきましたね。

モデルの学習

モデルの学習は、前回とほぼ同じです。
ただし、今回は GPUを使うので、データをGPUにデータを乗せる必要があります。

lr も調整し、0.001 にしました。

from tqdm import tqdm

model = MNISTCNNModel().to(device)

criterion = nn.NLLLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    total_loss = 0
    for images, labels in tqdm(train_loader):
        optimizer.zero_grad()
        outputs = model(images.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # loss は平均を取って表示する
    print(f'Epoch: {epoch + 1}, Loss: {total_loss / len(train_loader)}')

.to(device) で、GPU にデータを乗せることができます。

モデルの評価

モデルの評価も、前回とほぼ同じです。

correct = 0
total = 0
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(device)).sum().item()

print(f"Accuracy: {100 * correct / total}%")
Accuracy: 99.18%

99%を超える精度が出せました!

え?こんなに頑張ったのに、これくらいしか精度が上がらないの?と思うかもしれません。
そんなものです。ていうかそもそもここからの伸びしろってどんなに頑張って最強になったとしても、1%もないので仕方ないです。

というかよくよく考えると 10 クラス分類はクソ適当に予測すると10%になるわけですが、これが99%超えるのは結構すごくないですか?

  • エポック数を増やしてみる
    • もし、Loss が10エポック経ってもまだ減りそうな雰囲気があれば有効かもしれません
    • 過学習には注意してください!
    過学習について

    一般にモデルのパタメータが多いと、過学習しやすいです。

    過学習は英語で over-fitting といいます。より的確な和訳でいうと、過剰適合ということになります。
    パラメータが多いとその分だけ与えられたデータに対して過剰に適合してしまうのです。
    ある意味、過学習とはやや間違った日本語訳なのです。

    Linear で作ったモデルと比べると、今回のモデルはパラメータが多いので注意してください。

  • バッチサイズを変えてみる
    • バッチサイズを大きくすると、学習が早くなります
    • ただし、メモリが足りなくなる可能性があるので注意してください
  • 学習率を変えてみる
    • 結構難しいです。
  • 学習率を学習途中で動的に変えてみる
    • スケジューラーのようなものを使って学習率を下げていく手法が有名です
    • こちらも結構難しいです。
  • 最適化手法を変えてみる
    • Adam 以外にも色々試してみると良いかもしれません
  • モデルの構造を変えてみる
    • 余力があったらチャレンジしてみてください
    • 例えば
      • 畳み込み層を増やしたり
      • 畳み込み層のフィルターのサイズを変えたり
      • 畳み込み層のストライドを変えたり
      • プーリング層のサイズを変えたり
      • プーリング層のストライドを変えたり
      • Dropout の確率を変えたり
      • 活性化関数を変えたり
      • などなど、いろいろ試してみてください

参考までに、現在のトップは Accuracy: 99.91% だそうです。
https://paperswithcode.com/sota/image-classification-on-mnist?metric=Accuracy