
Intro2DL: Representative Architectures for Image Classification

Published on 2022/10/01

Image Classification Architectures

We will look at the following CNN architectures, which still form the basis of most state-of-the-art architectures today.

  • AlexNet
  • VGG
  • GoogleNet
  • ResNet
  • DenseNet

From this chapter on the models become heavier, so we run everything on Google Colab.
(Up to the previous chapter we used MPS, but there seems to be a bug in this chapter's PyTorch Lightning code that we cannot work around, so we use Colab's CUDA environment instead.)
Clicking the button below takes you straight to Google Colab.
The link is simply the notebook's GitHub URL with https://github.com/ replaced by https://colab.research.google.com/github/.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yoshida-chem/Intro2DL/blob/main/docs/07_image_arch/img_arch.ipynb)


import os
import json
import math
try:
    import japanize_matplotlib
except ModuleNotFoundError:
    !pip install japanize_matplotlib
    import japanize_matplotlib
import numpy as np 
import time
import copy
import requests
from PIL import Image
from types import SimpleNamespace
from io import StringIO

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sns
sns.set()

from tqdm import tqdm

from sklearn.metrics import mean_squared_error

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # when a GPU is available
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    # Skip MPS because it raises an error with PyTorch Lightning
    #elif torch.backends.mps.is_built():
    #    device = torch.device("mps:0")
    else:
        device = torch.device("cpu")
    return device
from google.colab import drive

drive.mount('/content/gdrive')
Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
# Path where PyTorch downloads datasets (to avoid downloading them twice)
# Where to save the models
if os.path.isdir('/content/gdrive/'):
  DATASET_PATH = "/content/gdrive/MyDrive/data"
  CHECKPOINT_PATH = "/content/gdrive/MyDrive/models/07_image_arch"
else:
  DATASET_PATH = "../data"
  CHECKPOINT_PATH = "../models/07_image_arch"
os.makedirs(DATASET_PATH, exist_ok=True)
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# Fix the seed for reproducibility
set_seed(42)

# Get the device information
device = get_device()
print(f"device: {device}")
device: cuda:0

Data Preparation

This time we use the CIFAR10 dataset.

# Compute the statistics used for normalization
train_dataset = CIFAR10(root=DATASET_PATH, train=True, download=True)
DATA_MEANS = (train_dataset.data / 255.0).mean(axis=(0,1,2))
DATA_STD = (train_dataset.data / 255.0).std(axis=(0,1,2))
print("Data mean", DATA_MEANS)
print("Data std", DATA_STD)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/gdrive/MyDrive/data/cifar-10-python.tar.gz
Extracting /content/gdrive/MyDrive/data/cifar-10-python.tar.gz to /content/gdrive/MyDrive/data
Data mean [0.49139968 0.48215841 0.44653091]
Data std [0.24703223 0.24348513 0.26158784]
# Normalization + data augmentation with Flip and ResizedCrop
test_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize(DATA_MEANS, DATA_STD)
                                     ])
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomResizedCrop((32,32),scale=(0.8,1.0),ratio=(0.9,1.1)),
                                      transforms.ToTensor(),
                                      transforms.Normalize(DATA_MEANS, DATA_STD)
                                     ])

# Prepare the datasets
# No data augmentation for validation
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)

# Split into training and validation data (same split for both; no augmentation at validation time)
set_seed(42)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000])
set_seed(42)
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000])

# Prepare the dataloaders
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=2)
val_loader = data.DataLoader(val_set, batch_size=128, shuffle=False, drop_last=False, num_workers=2)
test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, drop_last=False, num_workers=2)
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
# Check that each channel is properly normalized
imgs, _ = next(iter(train_loader))

# Shape is (BatchSize, Channel, H, W)
print(imgs.shape)
print("Batch mean", imgs.mean(dim=[0,2,3]))
print("Batch std", imgs.std(dim=[0,2,3]))
torch.Size([128, 3, 32, 32])
Batch mean tensor([0.0231, 0.0006, 0.0005])
Batch std tensor([0.9865, 0.9849, 0.9868])
NUM_IMAGES = 4
images = [train_dataset[idx][0] for idx in range(NUM_IMAGES)]
# To apply the transform, the array must first be converted to a PIL Image
orig_images = [Image.fromarray(train_dataset.data[idx]) for idx in range(NUM_IMAGES)]
orig_images = [train_transform(img) for img in orig_images]

img_grid = torchvision.utils.make_grid(torch.stack(images + orig_images, dim=0), nrow=4, normalize=True, pad_value=0.5)
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(8,8))
plt.title("Augmentation examples on CIFAR10")
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()

Implementing the Architectures

As preparation, we create a dictionary that maps names to activation functions. This is because we later want to save hyperparameters with PyTorch Lightning (objects passed as arguments are not saved).

act_fn_by_name = {
    "tanh": nn.Tanh,
    "relu": nn.ReLU,
    "leakyrelu": nn.LeakyReLU,
    "gelu": nn.GELU
}

AlexNet

AlexNet is the model that won the ILSVRC 2012 image recognition competition. It consists of five convolutional layers, three pooling layers, two contrast normalization layers, and three fully connected layers. During training, dropout (p=0.5) is applied to the units of the fc6 and fc7 layers. A ReLU follows every convolutional and fully connected layer. Local Response Normalization is applied after the ReLU of the first and second convolutional layers; it appears to have improved generalization, and it is used here because Batch Normalization had not yet been developed at the time.

In the original paper the network is split into two blocks because of the GPUs available at the time, so here we merge them into a single one.
For details, see Section 3.5 (Overall Architecture) of the original paper.

airplane

https://proceedings.neurips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf
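The spatial size after each convolution follows the usual formula floor((H + 2*padding - kernel)/stride) + 1. As a small sanity check before the implementation (a sketch of my own, not part of the original notebook), the helper below applies it to the first AlexNet layer and the max pooling that follows, for a 224x224 input.

def conv_out_size(h, kernel, stride=1, padding=0):
    # floor((H + 2P - K) / S) + 1
    return (h + 2 * padding - kernel) // stride + 1

# 1st AlexNet conv: kernel=11, stride=4, padding=2 on a 224x224 input
print(conv_out_size(224, kernel=11, stride=4, padding=2))  # 55
# followed by 3x3 max pooling with stride 2
print(conv_out_size(55, kernel=3, stride=2))               # 27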

class AlexNet(nn.Module):

    def __init__(self, num_classes, act_fn_name="relu"):
        super().__init__()

        # Use SimpleNamespace so values can be accessed as self.hparams.name
        self.hparams = SimpleNamespace(num_classes=num_classes,
                                       act_fn_name=act_fn_name,
                                       act_fn=act_fn_by_name[act_fn_name])
        self._create_network()
        self._init_params()
    
    def _create_network(self):
        self.features = nn.Sequential(
            # 1st layer
            ## The two 48-channel halves are merged into channel=96. padding is changed to 2 so the W and H sizes match
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=2),
            self.hparams.act_fn(),
            ## These hyperparameters were chosen based on validation in the original paper
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 2nd layer
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2),
            self.hparams.act_fn(),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 3rd-5th layers
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
            self.hparams.act_fn(),
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
            self.hparams.act_fn(),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
            self.hparams.act_fn(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(in_features=256*6*6, out_features=4096),
            self.hparams.act_fn(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            self.hparams.act_fn(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=self.hparams.num_classes),
        )

    def _init_params(self):
        # Initialization that differs from the original paper
        # Initialize Conv2d based on the activation function
        # ResNet uses mode="fan_out" (computed on the output side); here we keep the default fan_in
        # https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L156 
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity=self.hparams.act_fn_name)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        #print(x.shape)
        x = self.classifier(x)

        return x
# Check the sizes during implementation
# N x C x H x W
tmp = torch.rand(1, 3, 224, 224)
print(tmp.shape)

cnn = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=2)
tmp_ = cnn(tmp)
print(tmp_.shape)

net = AlexNet(num_classes=10, act_fn_name="relu")
tmp_ = net(tmp)
print(tmp_.shape)
torch.Size([1, 3, 224, 224])
torch.Size([1, 96, 55, 55])
torch.Size([1, 10])

VGG

VGG is a model proposed at ILSVRC 2014 that uses a CNN with 16 or 19 layers.
VGGNet is essentially a scaled-up AlexNet: by stacking 3x3 filters, the activation function is applied more often, which increases the representational power. Compared with AlexNet, the contrast normalization layers are gone.

Its drawback is that, without a Global Average Pooling layer, the fully connected part has many parameters and the computation is heavy.
Here we implement VGG16, configuration D in the figure below. VGG19 differs in that each of the last three convolutional stages has one additional layer.
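To see why stacking 3x3 filters is attractive, here is a rough comparison (my own back-of-the-envelope check, assuming 64 input and output channels and ignoring biases): two stacked 3x3 convolutions cover the same 5x5 receptive field as a single 5x5 convolution but need fewer parameters, while applying the activation function twice.

c = 64
params_two_3x3 = 2 * (c * c * 3 * 3)  # two stacked 3x3 convolutions
params_one_5x5 = c * c * 5 * 5        # a single 5x5 convolution
print(params_two_3x3, params_one_5x5)  # 73728 vs 102400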

vggnet

https://arxiv.org/pdf/1409.1556.pdf

class VGG16Net(nn.Module):

    def __init__(self, num_classes, act_fn_name="relu"):
        super().__init__()

        # Use SimpleNamespace so values can be accessed as self.hparams.name
        self.hparams = SimpleNamespace(num_classes=num_classes,
                                       act_fn_name=act_fn_name,
                                       act_fn=act_fn_by_name[act_fn_name])
        self._create_network()
        self._init_params()
    
    def _create_network(self):
        self.features = nn.Sequential(
            # 1
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            self.hparams.act_fn(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            self.hparams.act_fn(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 2
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            self.hparams.act_fn(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            self.hparams.act_fn(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 3
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            self.hparams.act_fn(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            self.hparams.act_fn(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            self.hparams.act_fn(),
            nn.MaxPool2d(kernel_size=2, stride=2), 
            # 4
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            self.hparams.act_fn(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            self.hparams.act_fn(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            self.hparams.act_fn(),
            nn.MaxPool2d(kernel_size=2, stride=2), 
            # 5
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            self.hparams.act_fn(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            self.hparams.act_fn(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            self.hparams.act_fn(),
            nn.MaxPool2d(kernel_size=2, stride=2), 
        )

        self.classifier = nn.Sequential(
            nn.Linear(in_features=512*7*7, out_features=4096),
            self.hparams.act_fn(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            self.hparams.act_fn(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=self.hparams.num_classes),
        )

    def _init_params(self):
        # Initialization that differs from the original paper
        # Initialize Conv2d based on the activation function
        # ResNet uses mode="fan_out" (computed on the output side); here we keep the default fan_in
        # https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L156 
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity=self.hparams.act_fn_name)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x
# Check the sizes during implementation
# N x C x H x W
tmp = torch.rand(1, 3, 224, 224)
print(tmp.shape)

net = VGG16Net(num_classes=10, act_fn_name="relu")
tmp_ = net(tmp)
print(tmp_.shape)
torch.Size([1, 3, 224, 224])
torch.Size([1, 10])

GoogleNet

Next we implement GoogleNet, the model that won ILSVRC 2014, where VGG finished second.
Its distinctive features are the Inception block and global average pooling (GAP).

Inception Block

An Inception block applies four branches to the same feature map in parallel: 1x1, 3x3, and 5x5 convolutions plus max pooling. This lets the network look at the same data with different receptive fields. In the implementation, inserting a 1x1 convolution (also called a bottleneck layer or pointwise convolution) means the convolution along the channel dimension is done first and the spatial convolution afterwards; splitting the computation this way reduces the number of operations. Of course, learning only 5x5 convolutions would in theory be more powerful, but it is not only heavier in compute and memory, it also tends to overfit more easily.
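To get a feel for how much the 1x1 bottleneck saves, here is a rough parameter count (my own example numbers, not from the paper): producing 32 output channels from 192 input channels with a direct 5x5 convolution, versus first reducing to 16 channels with a 1x1 convolution and then applying the 5x5 convolution.

c_in, c_red, c_out = 192, 16, 32
direct_5x5 = c_in * c_out * 5 * 5                               # 153600 parameters
with_bottleneck = c_in * c_red * 1 * 1 + c_red * c_out * 5 * 5  # 3072 + 12800 = 15872
print(direct_5x5, with_bottleneck)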

inception

The full Inception block can be implemented as follows.

class InceptionBlock(nn.Module):
    
    def __init__(self, c_in, c_red : dict, c_out : dict, act_fn):
        """
        Inputs:
            c_in - Number of input feature maps from the previous layers
            c_red - Dictionary with keys "3x3" and "5x5" specifying the output of the dimensionality reducing 1x1 convolutions
            c_out - Dictionary with keys "1x1", "3x3", "5x5", and "max"
            act_fn - Activation class constructor (e.g. nn.ReLU)
        """
        super().__init__()
        
        # 1x1 convolution branch
        self.conv_1x1 = nn.Sequential(
            nn.Conv2d(c_in, c_out["1x1"], kernel_size=1),
            nn.BatchNorm2d(c_out["1x1"]),
            act_fn()
        )
        
        # 3x3 convolution branch
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(c_in, c_red["3x3"], kernel_size=1),
            nn.BatchNorm2d(c_red["3x3"]),
            act_fn(),
            nn.Conv2d(c_red["3x3"], c_out["3x3"], kernel_size=3, padding=1),
            nn.BatchNorm2d(c_out["3x3"]),
            act_fn()
        )
        
        # 5x5 convolution branch
        self.conv_5x5 = nn.Sequential(
            nn.Conv2d(c_in, c_red["5x5"], kernel_size=1),
            nn.BatchNorm2d(c_red["5x5"]),
            act_fn(),
            nn.Conv2d(c_red["5x5"], c_out["5x5"], kernel_size=5, padding=2),
            nn.BatchNorm2d(c_out["5x5"]),
            act_fn()
        )
        
        # Max-pool branch
        self.max_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
            nn.Conv2d(c_in, c_out["max"], kernel_size=1),
            nn.BatchNorm2d(c_out["max"]),
            act_fn()
        )

    def forward(self, x):
        x_1x1 = self.conv_1x1(x)
        x_3x3 = self.conv_3x3(x)
        x_5x5 = self.conv_5x5(x)
        x_max = self.max_pool(x)
        # Concatenate along the channel dimension
        x_out = torch.cat([x_1x1, x_3x3, x_5x5, x_max], dim=1)
        return x_out

Using the Inception block above, we implement GoogleNet. The original paper describes a fairly large model, as shown below, but since we only want to solve the comparatively simple CIFAR10 task here, the implementation is scaled down. Note also that Batch Normalization was proposed after the original GoogleNet, but this implementation uses it in GoogleNet as well.

Global Average Pooling(GAP)

GAP reduces the number of parameters by computing the mean of each channel. Averaging per channel can also be expected to suppress overfitting. In addition, because the number of output neurons depends only on the number of channels, the output no longer depends on the input image size.
GAP is also used in ResNet; PyTorch has no dedicated GAP layer, so it is used as nn.AdaptiveAvgPool2d((1,1)). https://www.テクめも.com/entry/pytorch-pooling
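This size-independence is easy to check (a minimal sketch): nn.AdaptiveAvgPool2d((1,1)) returns a (N, C, 1, 1) tensor regardless of the spatial size of the input.

gap = nn.AdaptiveAvgPool2d((1, 1))
print(gap(torch.rand(1, 128, 32, 32)).shape)    # torch.Size([1, 128, 1, 1])
print(gap(torch.rand(1, 128, 224, 224)).shape)  # torch.Size([1, 128, 1, 1])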

google

class GoogleNet(nn.Module):

    def __init__(self, num_classes=10, act_fn_name="relu", **kwargs):
        super().__init__()
        self.hparams = SimpleNamespace(num_classes=num_classes,
                                       act_fn_name=act_fn_name,
                                       act_fn=act_fn_by_name[act_fn_name])
        self._create_network()
        self._init_params()

    def _create_network(self):
        # Increase the channel size
        self.input_net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            self.hparams.act_fn()
        )
        # Inception Blocks
        self.inception_blocks = nn.Sequential(
            InceptionBlock(64, c_red={"3x3": 32, "5x5": 16}, c_out={"1x1": 16, "3x3": 32, "5x5": 8, "max": 8}, act_fn=self.hparams.act_fn),
            InceptionBlock(64, c_red={"3x3": 32, "5x5": 16}, c_out={"1x1": 24, "3x3": 48, "5x5": 12, "max": 12}, act_fn=self.hparams.act_fn),
            nn.MaxPool2d(3, stride=2, padding=1),  # 32x32 => 16x16
            InceptionBlock(96, c_red={"3x3": 32, "5x5": 16}, c_out={"1x1": 24, "3x3": 48, "5x5": 12, "max": 12}, act_fn=self.hparams.act_fn),
            InceptionBlock(96, c_red={"3x3": 32, "5x5": 16}, c_out={"1x1": 16, "3x3": 48, "5x5": 16, "max": 16}, act_fn=self.hparams.act_fn),
            InceptionBlock(96, c_red={"3x3": 32, "5x5": 16}, c_out={"1x1": 16, "3x3": 48, "5x5": 16, "max": 16}, act_fn=self.hparams.act_fn),
            InceptionBlock(96, c_red={"3x3": 32, "5x5": 16}, c_out={"1x1": 32, "3x3": 48, "5x5": 24, "max": 24}, act_fn=self.hparams.act_fn),
            nn.MaxPool2d(3, stride=2, padding=1),  # 16x16 => 8x8
            InceptionBlock(128, c_red={"3x3": 48, "5x5": 16}, c_out={"1x1": 32, "3x3": 64, "5x5": 16, "max": 16}, act_fn=self.hparams.act_fn),
            InceptionBlock(128, c_red={"3x3": 48, "5x5": 16}, c_out={"1x1": 32, "3x3": 64, "5x5": 16, "max": 16}, act_fn=self.hparams.act_fn)
        )
        # Output head
        self.output_net = nn.Sequential(
            # A Linear head alone would need a C*W*H input; thanks to GAP only the channel count C=128 matters, reducing computation
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, self.hparams.num_classes)
        )

    def _init_params(self):
        # Initialize based on the activation function
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, nonlinearity=self.hparams.act_fn_name)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.input_net(x)
        x = self.inception_blocks(x)
        x = self.output_net(x)
        return x
# Check the sizes during implementation
# N x C x H x W
tmp = torch.rand(1, 3, 224, 224)
print(tmp.shape)

net = GoogleNet(num_classes=10, act_fn_name="relu")
tmp_ = net(tmp)
print(tmp_.shape)
torch.Size([1, 3, 224, 224])
torch.Size([1, 10])

ResNet

Residual Connections

With a residual connection, instead of modeling x_{l+1} = F(x_l), we model x_{l+1} = x_l + F(x_l), where F() is a nonlinear function. Computing the gradient through such a residual connection gives the expression below; it is less affected by F, which allows stable gradient propagation. One caveat for the implementation is that x_l and F(x_l) have to be added together, so if F() reduces the resolution, the input must be transformed so that the feature-map shapes match.

\partial x_{l+1}/\partial x_l = I + \partial F(x_l)/\partial x_l
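The identity term can also be checked numerically. Below is a toy example of my own (not part of the original notebook) with a simple linear F(x) = 0.5x: the gradient of x + F(x) with respect to x is 1 + 0.5 = 1.5 for every element.

x = torch.tensor([1.0, 2.0], requires_grad=True)
W = torch.tensor([[0.5, 0.0], [0.0, 0.5]])
out = x + x @ W          # residual connection with F(x) = x W
out.sum().backward()
print(x.grad)            # tensor([1.5000, 1.5000]) = identity + dF/dx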

Several versions of the ResNet block have been proposed; here we look at three of them.

  • Original ResNet block: ReLU is applied after the skip connection. The implementation from the original paper.
  • Pre-Activation ResNet block: ReLU is applied at the beginning of F. For deeper networks this is known to work better, because the gradient flow is guaranteed to contain the identity term shown above. (figure below; https://arxiv.org/abs/1603.05027)
  • ResNet block with a bottleneck: aims to increase the number of channels inside the block without increasing the computation. As in the Inception block, 1x1 convolutions are used to reduce the computational cost.

res

class OriginalResNetBlock(nn.Module):

    def __init__(self, c_in, act_fn, subsample=False, c_out=-1):
        """
        Inputs:
            c_in - Number of input features
            act_fn - Activation class constructor (e.g. nn.ReLU)
            subsample - If True, the resolution is reduced inside F by using stride=2. If False, it is kept as is.
            c_out - Only used when subsample=True. Otherwise the same as the input.
        """
        super().__init__()
        if not subsample:
            c_out = c_in
            
        # The function F from the explanation above
        self.net = nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=1 if not subsample else 2, bias=False),  # No bias needed as the Batch Norm handles it
            nn.BatchNorm2d(c_out),
            act_fn(),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(c_out)
        )
        
        # Downsample the original input with a 1x1 convolution with stride 2
        self.downsample = nn.Conv2d(c_in, c_out, kernel_size=1, stride=2) if subsample else None
        self.act_fn = act_fn()

    def forward(self, x):
        # F(x)
        z = self.net(x)
        # x: if the resolution was reduced with stride=2, make the shapes match
        if self.downsample is not None:
            x = self.downsample(x)
        # skip connection
        out = x + z
        # ReLU comes after the skip connection
        out = self.act_fn(out)
        return out
class PreActResNetBlock(nn.Module):

    def __init__(self, c_in, act_fn, subsample=False, c_out=-1):
        """
        Inputs:
            c_in - Number of input features
            act_fn - Activation class constructor (e.g. nn.ReLU)
            subsample - If True, the resolution is reduced inside F by using stride=2. If False, it is kept as is.
            c_out - Only used when subsample=True. Otherwise the same as the input.
        """
        super().__init__()
        if not subsample:
            c_out = c_in
            
        # The function F from the explanation above
        self.net = nn.Sequential(
            nn.BatchNorm2d(c_in),
            # Apply ReLU at the beginning of F (pre-activation)
            act_fn(),
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=1 if not subsample else 2, bias=False),
            nn.BatchNorm2d(c_out),
            act_fn(),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1, bias=False)
        )
        
        # When applying the 1x1 convolution on the shortcut, the non-linearity must be applied as well
        self.downsample = nn.Sequential(
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_out, kernel_size=1, stride=2, bias=False)
        ) if subsample else None

    def forward(self, x):
        # F(x)
        z = self.net(x)
        # x: if the resolution was reduced with stride=2, make the shapes match
        if self.downsample is not None:
            x = self.downsample(x)
        # skip connection
        out = x + z
        return out
class BottleNeckResNetBlock(nn.Module):

    def __init__(self, c_in, act_fn, subsample=False, c_out=-1, **kwargs):
        """
        Inputs:
            c_in - Number of input features
            act_fn - Activation class constructor (e.g. nn.ReLU)
            subsample - If True, the resolution is reduced inside F by using stride=2. If False, it is kept as is.
            c_out - Number of output features. In the paper, c_out = c_in*4.
            https://arxiv.org/pdf/1512.03385.pdf
        """
        super().__init__()
            
        # The function F from the explanation above
        self.net = nn.Sequential(
            nn.Conv2d(c_in, c_in, kernel_size=1, stride=1 if not subsample else 2, bias=False),  # No bias needed as the Batch Norm handles it
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_in, kernel_size=3, padding=1, bias=False),  # No bias needed as the Batch Norm handles it
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_out, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_out)
        )
        
        # Downsample the original input with a 1x1 convolution (stride 2 only when subsampling)
        # The bottleneck type also changes the channel count (e.g. to 256), so the shortcut always adjusts the channels
        self.downsample = nn.Conv2d(c_in, c_out, kernel_size=1, stride=1 if not subsample else 2)
        self.act_fn = act_fn()

    def forward(self, x):
        # F(x)
        z = self.net(x)
        # x: if the resolution was reduced with stride=2, make the shapes match
        if self.downsample is not None:
            x = self.downsample(x)
        # skip connection
        out = x + z
        # ReLU comes after the skip connection
        out = self.act_fn(out)
        return out
resnet_blocks_by_name = {
    "ResNetBlock": OriginalResNetBlock,
    "PreActResNetBlock": PreActResNetBlock,
    "BottleNeckResNetBlock": BottleNeckResNetBlock
}

The overall ResNet architecture is built by stacking multiple ResNet blocks, some of which downsample the input. When we talk about the ResNet blocks of the whole network, we usually group them by output shape. Saying that a ResNet has [3,3,3] blocks therefore means three groups of three ResNet blocks each, with subsampling taking place in the 4th and 7th blocks.

class ResNet(nn.Module):

    def __init__(self, num_classes=10, num_blocks=[3,3,3], c_hidden=[16,32,64], act_fn_name="relu", block_name="ResNetBlock", **kwargs):
        """
        Inputs: 
            num_classes - Number of classification outputs (10 for CIFAR10)
            num_blocks - List with the number of ResNet blocks to use. The first block of each group uses downsampling, except the first.
            c_hidden - List with the hidden dimensionalities in the different blocks. Usually multiplied by 2 the deeper we go.
            act_fn_name - Name of the activation function to use, looked up in "act_fn_by_name"
            block_name - Name of the ResNet block, looked up in "resnet_blocks_by_name"
        """
        super().__init__()
        assert block_name in resnet_blocks_by_name
        self.hparams = SimpleNamespace(num_classes=num_classes, 
                                       c_hidden=c_hidden, 
                                       num_blocks=num_blocks, 
                                       act_fn_name=act_fn_name,
                                       act_fn=act_fn_by_name[act_fn_name],
                                       block_class=resnet_blocks_by_name[block_name])
        self._create_network()
        self._init_params()

    def _create_network(self):
        c_hidden = self.hparams.c_hidden
        
        # A first convolution on the original image to scale up the channel size
        if self.hparams.block_class == PreActResNetBlock: # => Don't apply non-linearity on output
            self.input_net = nn.Sequential(
                nn.Conv2d(3, c_hidden[0], kernel_size=3, padding=1, bias=False)
            )
        else:
            self.input_net = nn.Sequential(
                nn.Conv2d(3, c_hidden[0], kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_hidden[0]),
                self.hparams.act_fn()
            )
        
        # Creating the ResNet blocks
        blocks = []
        for block_idx, block_count in enumerate(self.hparams.num_blocks):
            for bc in range(block_count):
                subsample = (bc == 0 and block_idx > 0) # Subsample the first block of each group, except the very first one.
                blocks.append(
                    self.hparams.block_class(c_in=c_hidden[block_idx if not subsample else (block_idx-1)],
                                             act_fn=self.hparams.act_fn,
                                             subsample=subsample,
                                             c_out=c_hidden[block_idx])
                )
        self.blocks = nn.Sequential(*blocks)
        
        # Mapping to classification output
        self.output_net = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(c_hidden[-1], self.hparams.num_classes)
        )

    def _init_params(self):
        # Based on our discussion in Tutorial 4, we should initialize the convolutions according to the activation function
        # Fan-out focuses on the gradient distribution, and is commonly used in ResNets
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity=self.hparams.act_fn_name)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.input_net(x)
        x = self.blocks(x)
        x = self.output_net(x)
        return x
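The notebook does not include a size check for ResNet, but following the same pattern as the cells above, a minimal one (with a CIFAR-sized 32x32 input) could look like this.

# Check the sizes during implementation
# N x C x H x W
tmp = torch.rand(1, 3, 32, 32)

net = ResNet(num_classes=10, num_blocks=[3,3,3], c_hidden=[16,32,64],
             act_fn_name="relu", block_name="ResNetBlock")
tmp_ = net(tmp)
print(tmp.shape)   # torch.Size([1, 3, 32, 32])
print(tmp_.shape)  # torch.Size([1, 10])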

DenseNet

DenseNet uses skip connections in a slightly different way from ResNet, which removes the need to learn redundant feature maps.
Deeper in the network, the model learns abstract features in order to recognize patterns. Some complex patterns, however, are combinations of abstract features (hands, faces, ...) and low-level features (edges, basic colors, ...), so having to re-learn such low-level features in deep layers would be wasteful.
DenseNet is an efficient way to reuse features: as in the figure below, each convolution takes all previous feature maps as input and only adds a small number of filters (the number of added filters k is called the growth rate). (https://arxiv.org/pdf/1608.06993.pdf)
The layer placed at the very end of each block is called a Transition Layer; it reduces the height, width, and channel dimensions of the feature maps.
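As a concrete example of the growth rate (my own arithmetic, using the default settings of the implementation below: growth_rate=16, bn_size=2, and 6 layers per block), each dense block adds 6*16 channels to its input, and each transition layer then halves the channel count.

growth_rate, bn_size, layers_per_block = 16, 2, 6
c = growth_rate * bn_size                    # channels after the input convolution: 32
for block_idx in range(4):
    c += layers_per_block * growth_rate      # each DenseLayer adds growth_rate channels
    if block_idx < 3:                        # TransitionLayer after all but the last block
        c //= 2
print(c)  # 184 channels enter the classification head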

denseb

The implementation consists of the DenseLayer, the DenseBlock built from DenseLayers, and the TransitionLayer that comes at the very end of each DenseBlock.

dense

class DenseLayer(nn.Module):
    
    def __init__(self, c_in, bn_size, growth_rate, act_fn):
        """
        Inputs:
            c_in - Number of input channels
            bn_size - Bottleneck size (factor of growth rate) for the output of the 1x1 convolution. Typically between 2 and 4.
            growth_rate - Number of output channels of the 3x3 convolution
            act_fn - Activation class constructor (e.g. nn.ReLU)
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, bn_size * growth_rate, kernel_size=1, bias=False),
            nn.BatchNorm2d(bn_size * growth_rate),
            act_fn(),
            nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
        )
        
    def forward(self, x):
        out = self.net(x)
        # The output channels are concatenated with the original input
        out = torch.cat([out, x], dim=1)
        return out
class DenseBlock(nn.Module):
    
    def __init__(self, c_in, num_layers, bn_size, growth_rate, act_fn):
        """
        Inputs:
            c_in - Number of input channels
            num_layers - Number of dense layers to apply in the block
            bn_size - Bottleneck size to use in the dense layers
            growth_rate - Growth rate to use in the dense layers
            act_fn - Activation function to use in the dense layers
        """
        super().__init__()
        layers = []
        for layer_idx in range(num_layers):
            layers.append(
                # The input is the original input concatenated with the feature maps of all previous layers
                # The output is c_out = c_in + layer_idx * growth_rate + growth_rate
                DenseLayer(c_in=c_in + layer_idx * growth_rate, 
                           bn_size=bn_size,
                           growth_rate=growth_rate,
                           act_fn=act_fn)
            )
        self.block = nn.Sequential(*layers)
        
    def forward(self, x):
        out = self.block(x)
        return out
class TransitionLayer(nn.Module):
    
    def __init__(self, c_in, c_out, act_fn):
        super().__init__()
        self.transition = nn.Sequential(
            nn.BatchNorm2d(c_in),
            act_fn(),
            # The 1x1 convolution reduces the dimensionality along the channel axis
            nn.Conv2d(c_in, c_out, kernel_size=1, bias=False),
            # Apply average pooling with kernel size 2 and stride 2 to reduce height and width
            nn.AvgPool2d(kernel_size=2, stride=2) # Average the output for each 2x2 pixel group
        )
        
    def forward(self, x):
        return self.transition(x)

Using the components above, we build the DenseNet.

class DenseNet(nn.Module):
    
    def __init__(self, num_classes=10, num_layers=[6,6,6,6], bn_size=2, growth_rate=16, act_fn_name="relu", **kwargs):
        super().__init__()
        self.hparams = SimpleNamespace(num_classes=num_classes,
                                       num_layers=num_layers,
                                       bn_size=bn_size,
                                       growth_rate=growth_rate,
                                       act_fn_name=act_fn_name,
                                       act_fn=act_fn_by_name[act_fn_name])
        self._create_network()
        self._init_params()
        
    def _create_network(self):
        c_hidden = self.hparams.growth_rate * self.hparams.bn_size # The start number of hidden channels
        
        # A first convolution on the original image to scale up the channel size
        self.input_net = nn.Sequential(
            nn.Conv2d(3, c_hidden, kernel_size=3, padding=1) # No batch norm or activation function as done inside the Dense layers
        )
        
        # Creating the dense blocks, eventually including transition layers
        blocks = []
        for block_idx, num_layers in enumerate(self.hparams.num_layers):
            blocks.append( 
                DenseBlock(c_in=c_hidden, 
                           num_layers=num_layers, 
                           bn_size=self.hparams.bn_size,
                           growth_rate=self.hparams.growth_rate,
                           act_fn=self.hparams.act_fn)
            )
            # Output of the DenseBlock = c_in + (number of layers in the block) * growth_rate
            c_hidden = c_hidden + num_layers * self.hparams.growth_rate 
            if block_idx < len(self.hparams.num_layers)-1: # Don't apply transition layer on last block
                blocks.append(
                    TransitionLayer(c_in=c_hidden,
                                    c_out=c_hidden // 2,
                                    act_fn=self.hparams.act_fn))
                c_hidden = c_hidden // 2
                
        self.blocks = nn.Sequential(*blocks)
        
        # Mapping to classification output
        self.output_net = nn.Sequential(
            nn.BatchNorm2d(c_hidden), # The features have not passed a non-linearity until here.
            self.hparams.act_fn(),
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(c_hidden, self.hparams.num_classes)
        )

    def _init_params(self):
        # Based on our discussion in Tutorial 4, we should initialize the convolutions according to the activation function
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity=self.hparams.act_fn_name)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.input_net(x)
        x = self.blocks(x)
        x = self.output_net(x)
        return x
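As with ResNet, a size check for DenseNet is not in the notebook; a minimal sketch with a CIFAR-sized input could be:

# Check the sizes during implementation
# N x C x H x W
tmp = torch.rand(1, 3, 32, 32)

net = DenseNet(num_classes=10, num_layers=[6,6,6,6], bn_size=2, growth_rate=16, act_fn_name="relu")
tmp_ = net(tmp)
print(tmp.shape)   # torch.Size([1, 3, 32, 32])
print(tmp_.shape)  # torch.Size([1, 10])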

PyTorch Lightning

Using PyTorch Lightning means we no longer have to write much of the boilerplate ourselves, so we can spend our time on the more essential work.
For example, we do not need to write .to(device) calls, which already simplifies the implementation.

Here we subclass pl.LightningModule (which itself inherits from torch.nn.Module). This class has five main methods; if you need to change anything beyond these, there are other functions you can override, so consult the documentation.

  • Initialization (__init__): create the parameters and models you need
  • Optimizer (configure_optimizers): create the optimizer, learning-rate scheduler, and so on
  • Training loop (training_step): only define the loss computation for a single batch (the optimizer.zero_grad(), loss.backward(), and optimizer.step() loop, as well as logging/checkpointing, runs in the background)
  • Validation loop (validation_step): as with training, only define what happens for each step
  • Test loop (test_step): same as validation, but only run on the test set

Reference: https://qiita.com/ground0state/items/c1d705ca2ee329cdfae4
Reference: https://tech.jxpress.net/entry/2021/11/17/112214

try:
  import pytorch_lightning as pl
except ModuleNotFoundError:
  !pip install pytorch_lightning
  import pytorch_lightning as pl
from torchmetrics import Accuracy, Precision, Recall, F1Score, MetricCollection

# In PyTorch Lightning, the seed can be fixed with the following call
pl.seed_everything(42)
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
42
model_dict = {
    "AlexNet": AlexNet,
    "VGG16Net": VGG16Net,
    "GoogleNet": GoogleNet,
    "ResNet": ResNet,
    "DenseNet": DenseNet
}

def create_model(model_name, model_hparams):
    if model_name in model_dict:
        return model_dict[model_name](**model_hparams)
    else:
        assert False, f"Unknown model name \"{model_name}\". Available models are: {str(model_dict.keys())}"
class CIFARModule(pl.LightningModule):

    def __init__(self, model_name, model_hparams, optimizer_name, optimizer_hparams):
        """
        Inputs:
            model_name - Name of the model/CNN to run. Used for creating the model (see function below)
            model_hparams - Hyperparameters for the model, as dictionary.
            optimizer_name - Name of the optimizer to use. Currently supported: Adam, SGD
            optimizer_hparams - Hyperparameters for the optimizer, as dictionary. This includes learning rate, weight decay, etc.
        """
        super().__init__()
        # Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
        self.save_hyperparameters()
        # Create model
        self.model = create_model(model_name, model_hparams)
        # Create loss module
        self.loss_module = nn.CrossEntropyLoss()
        # Example input for visualizing the graph in Tensorboard
        self.example_input_array = torch.zeros((1, 3, 32, 32), dtype=torch.float32)

        self.train_metrics = MetricCollection([Accuracy(), Precision(), Recall(), F1Score()], prefix='train_')
        self.val_metrics = MetricCollection([Accuracy(), Precision(), Recall(), F1Score()], prefix='val_')
        self.test_metrics = MetricCollection([Accuracy(), Precision(), Recall(), F1Score()], prefix='test_')

    def forward(self, imgs):
        # Forward function that is run when visualizing the graph
        return self.model(imgs)

    def configure_optimizers(self):
        # We will support Adam or SGD as optimizers.
        if self.hparams.optimizer_name == "Adam":
            # AdamW is Adam with a correct implementation of weight decay (see here for details: https://arxiv.org/pdf/1711.05101.pdf)
            optimizer = optim.AdamW(
                self.parameters(), **self.hparams.optimizer_hparams)
        elif self.hparams.optimizer_name == "SGD":
            optimizer = optim.SGD(self.parameters(), **self.hparams.optimizer_hparams)
        else:
            assert False, f"Unknown optimizer: \"{self.hparams.optimizer_name}\""

        # We will reduce the learning rate by 0.1 after 100 and 150 epochs
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        # "batch" is the output of the training data loader.
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)

        # Logs the accuracy per epoch to tensorboard (weighted average over batches)
        preds_for_metrics = preds.argmax(dim=-1)
        #acc = (preds_for_metrics == labels).float().mean()
        #self.log('train_acc', acc, on_step=False, on_epoch=True)
        self.train_metrics(preds_for_metrics, labels)
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        self.log('train_loss', loss, logger=True)
        return loss  # Return tensor to call ".backward" on

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs)
        preds_for_metrics = preds.argmax(dim=-1)
        loss = self.loss_module(preds, labels)
        self.log('val_loss', loss, logger=True)
        # By default logs it per epoch (weighted average over batches)
        #acc = (labels == preds_for_metrics).float().mean()
        #self.log('val_acc', acc)
        self.val_metrics(preds_for_metrics, labels)
        self.log_dict(self.val_metrics, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)
        self.log('test_loss', loss, logger=True)

        # By default logs it per epoch (weighted average over batches), and returns it afterwards
        preds_for_metrics = preds.argmax(dim=-1)
        #acc = (labels == preds_for_metrics).float().mean()
        #self.log('test_acc', acc)
        self.test_metrics(preds_for_metrics, labels)
        self.log_dict(self.test_metrics, logger=True)
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping

def train_model(model_name, save_name=None, **kwargs):
    """
    Inputs:
        model_name - Name of the model you want to run. Is used to look up the class in "model_dict"
        save_name (optional) - If specified, this name will be used for creating the checkpoint and logging directory.
    """
    if save_name is None:
        save_name = model_name
        
    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, save_name),                          # Where to save models
                        accelerator="gpu" if str(device)=="cuda:0" else "cpu",
                        devices=1,                                                                           # We run on a single device (GPU if possible)
                        max_epochs=10,                                                                       # How many epochs to train for if no patience is set
                        callbacks=[ModelCheckpoint(save_weights_only=True, mode="min", monitor="val_loss"),  # Save the best checkpoint based on the minimum val_loss recorded. Saves only weights and not optimizer
                                LearningRateMonitor("epoch"),                                                # Log learning rate every epoch
                                EarlyStopping(monitor="val_loss", mode="min")],                              # early stopping
                        enable_progress_bar=True)

    trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need
    
    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, save_name + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        model = CIFARModule.load_from_checkpoint(pretrained_filename) # Automatically loads the model with the saved hyperparameters
    else:
        pl.seed_everything(42) # To be reproducible
        model = CIFARModule(model_name=model_name, **kwargs)
        trainer.fit(model, train_loader, val_loader)
        model = CIFARModule.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # Load best checkpoint after training
        
    # Test best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result[0]["test_Accuracy"], "val": val_result[0]["test_Accuracy"]}
    
    return model, result

Training each model

AlexNet and VGG do not use GAP, so their fully connected parameters would have to be adjusted to the input image size.
GoogleNet, ResNet, and DenseNet do not depend on the input image size thanks to GAP, so here we train these three models.

Training GoogleNet

googlenet_model, googlenet_results = train_model(model_name="GoogleNet", 
                                                 model_hparams={"num_classes": 10, 
                                                                "act_fn_name": "relu"}, 
                                                 optimizer_name="Adam",
                                                 optimizer_hparams={"lr": 1e-3,
                                                                    "weight_decay": 1e-4})
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
WARNING:pytorch_lightning.loggers.tensorboard:Missing logger folder: /content/gdrive/MyDrive/models/07_image_arch/GoogleNet/lightning_logs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type             | Params | In sizes       | Out sizes
--------------------------------------------------------------------------------
0 | model         | GoogleNet        | 260 K  | [1, 3, 32, 32] | [1, 10]  
1 | loss_module   | CrossEntropyLoss | 0      | ?              | ?        
2 | train_metrics | MetricCollection | 0      | ?              | ?        
3 | val_metrics   | MetricCollection | 0      | ?              | ?        
4 | test_metrics  | MetricCollection | 0      | ?              | ?        
--------------------------------------------------------------------------------
260 K     Trainable params
0         Non-trainable params
260 K     Total params
1.043     Total estimated model params size (MB)

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
print("GoogleNet Results", googlenet_results)
GoogleNet Results {'test': 0.8115000128746033, 'val': 0.8240000009536743}
# Load tensorboard extension
%load_ext tensorboard
# Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH!
if os.path.isdir('/content/gdrive/'):
  %tensorboard --logdir /content/gdrive/MyDrive/models/07_image_arch/GoogleNet/lightning_logs
else:
  %tensorboard --logdir ../models/07_image_arch/GoogleNet/lightning_logs

Training ResNet

resnet_model, resnet_results = train_model(model_name="ResNet", 
                                           model_hparams={"num_classes": 10,
                                                          "c_hidden": [16,32,64],
                                                          "num_blocks": [3,3,3],
                                                          "act_fn_name": "relu",
                                                          "block_name": "ResNetBlock"}, 
                                           optimizer_name="SGD",
                                           optimizer_hparams={"lr": 0.1,
                                                              "momentum": 0.9,
                                                              "weight_decay": 1e-4})
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
WARNING:pytorch_lightning.loggers.tensorboard:Missing logger folder: /content/gdrive/MyDrive/models/07_image_arch/ResNet/lightning_logs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type             | Params | In sizes       | Out sizes
--------------------------------------------------------------------------------
0 | model         | ResNet           | 272 K  | [1, 3, 32, 32] | [1, 10]  
1 | loss_module   | CrossEntropyLoss | 0      | ?              | ?        
2 | train_metrics | MetricCollection | 0      | ?              | ?        
3 | val_metrics   | MetricCollection | 0      | ?              | ?        
4 | test_metrics  | MetricCollection | 0      | ?              | ?        
--------------------------------------------------------------------------------
272 K     Trainable params
0         Non-trainable params
272 K     Total params
1.090     Total estimated model params size (MB)

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
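
The next run keeps every hyperparameter the same and only switches `block_name` to `"PreActResNetBlock"`. As a reminder, the pre-activation variant only reorders the operations inside the block (BatchNorm and ReLU before each convolution, and no activation after the skip addition). A minimal sketch, not the actual class used above:

```python
import torch.nn as nn

class PreActResNetBlockSketch(nn.Module):
    """Illustrative pre-activation residual block: BN-ReLU-conv applied twice, then add the skip."""
    def __init__(self, c_in, c_out, subsample=False):
        super().__init__()
        stride = 2 if subsample else 1
        self.net = nn.Sequential(
            nn.BatchNorm2d(c_in),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1, bias=False),
        )
        # 1x1 convolution on the skip path when resolution/channel count changes
        self.downsample = nn.Conv2d(c_in, c_out, kernel_size=1, stride=2, bias=False) if subsample else None

    def forward(self, x):
        z = self.net(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return z + x  # no activation after the addition in the pre-activation variant
```
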



resnetpreact_model, resnetpreact_results = train_model(model_name="ResNet", 
                                                       model_hparams={"num_classes": 10,
                                                                      "c_hidden": [16,32,64],
                                                                      "num_blocks": [3,3,3],
                                                                      "act_fn_name": "relu",
                                                                      "block_name": "PreActResNetBlock"}, 
                                                       optimizer_name="SGD",
                                                       optimizer_hparams={"lr": 0.1,
                                                                          "momentum": 0.9,
                                                                          "weight_decay": 1e-4},
                                                       save_name="ResNetPreAct")
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
WARNING:pytorch_lightning.loggers.tensorboard:Missing logger folder: /content/gdrive/MyDrive/models/07_image_arch/ResNetPreAct/lightning_logs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type             | Params | In sizes       | Out sizes
--------------------------------------------------------------------------------
0 | model         | ResNet           | 272 K  | [1, 3, 32, 32] | [1, 10]  
1 | loss_module   | CrossEntropyLoss | 0      | ?              | ?        
2 | train_metrics | MetricCollection | 0      | ?              | ?        
3 | val_metrics   | MetricCollection | 0      | ?              | ?        
4 | test_metrics  | MetricCollection | 0      | ?              | ?        
--------------------------------------------------------------------------------
272 K     Trainable params
0         Non-trainable params
272 K     Total params
1.089     Total estimated model params size (MB)

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

if os.path.isdir('/content/gdrive/'):
  %tensorboard --logdir /content/gdrive/MyDrive/models/07_image_arch/ResNet/lightning_logs
else:
  %tensorboard --logdir ../models/07_image_arch/ResNet/lightning_logs

Training DenseNet
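
The `DenseNet` class was likewise defined earlier in the notebook. To recall what `growth_rate` and `bn_size` in the call below control, here is a minimal, illustrative sketch of a single dense layer (not the exact implementation used): each layer produces `growth_rate` new feature maps and concatenates them onto its input, with a 1x1 bottleneck of width `bn_size * growth_rate` in front of the 3x3 convolution.

```python
import torch
import torch.nn as nn

class DenseLayerSketch(nn.Module):
    """Illustrative DenseNet layer: BN-ReLU-1x1 conv (bottleneck), BN-ReLU-3x3 conv,
    then concatenate the new growth_rate channels onto the input."""
    def __init__(self, c_in, bn_size, growth_rate):
        super().__init__()
        self.net = nn.Sequential(
            nn.BatchNorm2d(c_in),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_in, bn_size * growth_rate, kernel_size=1, bias=False),
            nn.BatchNorm2d(bn_size * growth_rate),
            nn.ReLU(inplace=True),
            nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False),
        )

    def forward(self, x):
        # The input passes through unchanged; the output only grows by growth_rate channels
        return torch.cat([x, self.net(x)], dim=1)
```
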

densenet_model, densenet_results = train_model(model_name="DenseNet", 
                                               model_hparams={"num_classes": 10,
                                                              "num_layers": [6,6,6,6],
                                                              "bn_size": 2,
                                                              "growth_rate": 16,
                                                              "act_fn_name": "relu"}, 
                                               optimizer_name="Adam",
                                               optimizer_hparams={"lr": 1e-3,
                                                                  "weight_decay": 1e-4})
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
WARNING:pytorch_lightning.loggers.tensorboard:Missing logger folder: /content/gdrive/MyDrive/models/07_image_arch/DenseNet/lightning_logs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type             | Params | In sizes       | Out sizes
--------------------------------------------------------------------------------
0 | model         | DenseNet         | 239 K  | [1, 3, 32, 32] | [1, 10]  
1 | loss_module   | CrossEntropyLoss | 0      | ?              | ?        
2 | train_metrics | MetricCollection | 0      | ?              | ?        
3 | val_metrics   | MetricCollection | 0      | ?              | ?        
4 | test_metrics  | MetricCollection | 0      | ?              | ?        
--------------------------------------------------------------------------------
239 K     Trainable params
0         Non-trainable params
239 K     Total params
0.957     Total estimated model params size (MB)

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

if os.path.isdir('/content/gdrive/'):
  %tensorboard --logdir /content/gdrive/MyDrive/models/07_image_arch/DenseNet/lightning_logs
else:
  %tensorboard --logdir ../models/07_image_arch/DenseNet/lightning_logs

Summary of Results

All four models reach broadly similar accuracy (roughly 79-82% on the validation set). We can also see that ResNetPreAct offers no clear advantage over the standard ResNet for a shallow network like the one used here.
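
One way to see why the pre-activation design mainly pays off for much deeper networks comes from He et al., "Identity Mappings in Deep Residual Networks" (2016): with pre-activation blocks the update is $x_{l+1} = x_l + F(x_l)$, so unrolling from block $l$ to a later block $L$ gives

$$
x_L = x_l + \sum_{i=l}^{L-1} F(x_i), \qquad
\frac{\partial E}{\partial x_l} = \frac{\partial E}{\partial x_L}\left(1 + \frac{\partial}{\partial x_l}\sum_{i=l}^{L-1} F(x_i)\right),
$$

where $E$ is the loss. The additive constant 1 gives every block a direct gradient path back from the loss, an effect that only becomes important when many blocks are stacked; with just nine residual blocks per network here, the two variants behave almost identically.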

import tabulate
from IPython.display import display, HTML
all_models = [
    ("GoogleNet", googlenet_results, googlenet_model),
    ("ResNet", resnet_results, resnet_model),
    ("ResNetPreAct", resnetpreact_results, resnetpreact_model),
    ("DenseNet", densenet_results, densenet_model)
]
table = [[model_name,
          f"{100.0*model_results['val']:4.2f}%",
          f"{100.0*model_results['test']:4.2f}%",
          "{:,}".format(sum([np.prod(p.shape) for p in model.parameters()]))]
         for model_name, model_results, model in all_models]
display(HTML(tabulate.tabulate(table, tablefmt='html', headers=["Model", "Val Accuracy", "Test Accuracy", "Num Parameters"])))
| Model | Val Accuracy | Test Accuracy | Num Parameters |
| --- | --- | --- | --- |
| GoogleNet | 82.40% | 81.15% | 260,650 |
| ResNet | 78.98% | 79.35% | 272,378 |
| ResNetPreAct | 80.82% | 79.49% | 272,250 |
| DenseNet | 81.02% | 80.59% | 239,146 |

That's all for this article.
