🍊

ResNet18をRust&burnで実装する.

に公開

はじめに

機械学習クレートであるburnを用いて,ResNet18を実装した.データセットはMNISTを使用した.MNISTを使用した理由は,画像サイズが小さく,訓練の時間が短く済むからである.

この実装は自身の機械学習モデル実装の学習のために行った.

pytorchの公式実装コードを参考にして実装していく.

実装したコードはこちらのリポジトリにある.本記事ではモデルの実装に焦点をあてている.そのため,データセットの読み込みや学習ループなどについてはリポジトリを参照してほしい.
https://github.com/neruneruna7/my-resnet18

注意点として,MPSバックエンドでの実行を想定したコードだ.他の環境の人は,burnのドキュメントを参照して適切なバックエンドを使用するコードに変更してほしい.Cargo.tomlのburnクレートのfeaturesmain.rsを変更すればよいはずだ.

環境

Rust: 1.88
burn: 0.18
machine: macMini
macOS: 15.6

背景

論文などを読むときに,論文中の実験を追試したいがソースコードどこにあるかわからないことがある.そのため,論文中の記述をもとに自分でモデルを実装できるようになるため.

全体構造

おおよその全体構造は次のようになっている.Layerについては後ほど解説する.

層の具体的な説明は省略する.channelは,左側が入力チャネル数,右側が出力チャネル数と認識してほしい.3, 64であれば,入力チャネル数が3,出力チャネル数が64であることを意味する.

channel パラメータなど
畳み込み 3, 64 stride=2, padding=3, kernel size=7
バッチ正規化 64
ReLU活性化
最大値プール層 stride=2, padding=1, kernel size=3
Layer1 64, 64 stride=1
Layer2 64, 128 stride=2
Layer3 128, 256 stride=2
Layer4 256, 512 stride=2
適応型平均プール層 output_size=1
全結合層 512, 10 次元数2が入出力になる

基本的にテンソルの次元数は4になる.内訳は,[バッチサイズ,チャネル数,高さ,幅]である.
また,層のところには書いていないが,全結合層の前に次元数が4から2に変換される.この変換は,[バッチサイズ, チャネル数×高さ×幅]で行われる.

パラメータの意味

パラメータの意味について解説する.

  • stride: カーネルのピクセル移動幅
  • padding: 画像の周囲に処理の都合のため追加するピクセル数
  • kernel size: カーネルの大きさ
  • dilation: カーネルのピクセル間隔

このうち,dilationがテキストではややこしいので,図で説明する.
Oがカーネルの中心,Xがカーネルの処理対象,-がそのほかのピクセルだ.dilationが2のときは,処理対象のピクセルの間に間隔が空いていることがわかる.

dilation パターン
1 X X X
X O X
X X X
2 X - X - X
- - - - -
X - O - X
- - - - -
X - X - X

Layerの構造

resnet18のLayerは,BasicBlockと呼ばれるものが使われている.BasicBlockを2つ繋げて1つのLayerを構成している.ResNet50などではBottoleneckというものが使われているが,今回は省略する.

BasicBlockの構造

BasicBlockの構造は次のようになっている.特徴的な部分はショートカットだ.これはネットワークの入力を直接出力に加算.勾配消失問題という,ニューラルネットワークの層が深くなると発生する問題への対策となっている.

ここでin, outはBasicBlockに渡される整数型引数と考えてほしい.すなわち,Layerによって入出力チャネル数が異なるということだ.
これにより,テンソルの形状が異なるため,そのままでは加算できない.そこで,ショートカットに畳み込みとバッチ正規化を適用して,形状を合わせたうえで加算している.

channel パラメータ
畳み込み in, out stride=2 または1, padding=1, kernel size=3
バッチ正規化 out
畳み込み out, out stride=1, padding=1, kernel size=3
バッチ正規化 out
ショートカット Layer1以外では, 畳み込みとバッチ正規化適用
ReLU活性化

ショートカットは,Layer1以外では畳み込みとバッチ正規化を適用する.これはなぜか.それは,Layer1以外では入力チャネルと出力チャネルの値が異なるからだ.前述の全体構造の表を見てみてほしい.

実装

では実際の実装に移る.諸々のuse ...は次のとおりだ.

use crate::data::MnistBatch;についての解説は省略する.githubリポジトリを参照してほしい.

use burn::{
    nn::{
        BatchNorm, BatchNormConfig, Linear, LinearConfig, PaddingConfig2d, Relu,
        conv::{Conv2d, Conv2dConfig},
        loss::CrossEntropyLossConfig,
        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig, MaxPool2d, MaxPool2dConfig},
    },
    prelude::*,
    tensor::backend::AutodiffBackend,
    train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};

use crate::data::MnistBatch;

BasicBlockの実装

#[derive(Module, Debug)]
struct BasicBlock<B: Backend> {
    conv1: Conv2d<B>,
    // 正規化レイヤ
    bn1: BatchNorm<B, 2>,
    conv2: Conv2d<B>,
    bn2: BatchNorm<B, 2>,
    shortcut: Option<DownSample<B>>,
    activation: Relu,
}

impl<B: Backend> BasicBlock<B> {
    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let identity = x.clone();
        // ショートカットを先に計算(現在の実装の流れ)
        let shortcut = if let Some(shortcut) = &self.shortcut {
            shortcut.forward(x.clone())
        } else {
            x.clone()
        };

        // メイン経路
        let x = self.conv1.forward(x);
        let x = self.bn1.forward(x);
        let x = self.activation.forward(x);

        let x = self.conv2.forward(x);
        let x = self.bn2.forward(x);

        // ここで形状が合わない場合は明示的にログを出して panic する
        if x.dims() != shortcut.dims() {
            panic!(
                "BasicBlock: shape mismatch between main and shortcut \n
                Shape mismatch before add: main={:?}, shortcut={:?}",
                x.dims(),
                shortcut.dims(),
            );
        }

        let x = x + shortcut;

        self.activation.forward(x)
    }
}

#[derive(Config, Debug)]
struct BasicBlockConfig {
    /// 入力チャネル数
    in_planes: usize,
    /// 出力チャネル数
    out_planes: usize,
    /// カーネルの移動距離
    #[config(default = "[1, 1]")]
    stride: [usize; 2],
    #[config(default = 1)]
    dilation: usize,
    #[config(default = "None")]
    downsample: Option<DownSampleConfig>,
}

impl BasicBlockConfig {
    /// 入力チャネル数,出力チャネル数,デバイス
    fn init<B: Backend>(&self, device: &B::Device) -> BasicBlock<B> {
        BasicBlock {
            conv1: Conv2dConfig::new([self.in_planes, self.out_planes], [3, 3])
                .with_stride(self.stride)
                .with_padding(PaddingConfig2d::Explicit(1, 1))
                .with_bias(false)
                .init(device),
            bn1: BatchNormConfig::new(self.out_planes).init(device),
            conv2: Conv2dConfig::new([self.out_planes, self.out_planes], [3, 3])
                .with_padding(PaddingConfig2d::Explicit(1, 1))
                .with_bias(false)
                .init(device),
            bn2: BatchNormConfig::new(self.out_planes).init(device),
            // Use the block's input/output channel sizes for the shortcut 1x1 conv
            shortcut: self.downsample.as_ref().map(|ds| ds.init(device)),
            activation: Relu::new(),
        }
    }
}

#[derive(Module, Debug)]
struct DownSample<B: Backend> {
    conv: Conv2d<B>,
    bn: BatchNorm<B, 2>,
}

impl<B: Backend> DownSample<B> {
    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.conv.forward(x);
        self.bn.forward(x)
    }
}
#[derive(Config, Debug)]
struct DownSampleConfig {
    in_planes: usize,
    out_planes: usize,
    #[config(default = "[1, 1]")]
    kernel_size: [usize; 2],
    stride: [usize; 2],
}

impl DownSampleConfig {
    fn init<B: Backend>(&self, device: &B::Device) -> DownSample<B> {
        DownSample {
            conv: Conv2dConfig::new([self.in_planes, self.out_planes], self.kernel_size)
                .with_stride(self.stride)
                .init(device),
            bn: BatchNormConfig::new(self.out_planes).init(device),
        }
    }
}

モデルの構造とその設定が分離されているのがわかるだろうか.BasicBlock構造体が実際のモデルを表現し,BasicBlockConfig構造体がモデルの設定を表現している.また,BasicBlock構造体にはModuleトレイトが,BasicBlockConfig構造体にはConfigトレイトがderiveされている.細かいことはburnのドキュメントを参照してほしいが,これにより,モデルのパラメータの保存や読み込み,設定のデフォルト値設定が容易になる.

設定のデフォルト値設定について軽く触れる.BasicBlockConfig構造体のフィールドには,#[config(default = ...)]という属性がついているものがある.これは,そのフィールドのデフォルト値を指定している.例えば,dilationフィールドは,デフォルト値が1に設定されている.これにより,自動実装されるnewメソッドには,dilationフィールドを指定しなくてもよくなる.自分で指定したい場合は,with_dilationメソッドを使って,dilationフィールドを指定できる.
ただし,プリミティブな値以外はそのままではデフォルト値を指定できないことに注意してほしい.例えば,dilationフィールドは数値配列型だが,"[1,1]"のように文字列として指定している.

ショートカットに適用される畳み込みとバッチ正規化は,DownSampleに実装されている.

Layerの実装

BasicBlockを2つつなげただけである.

2つめのBasicBlockでは,入力チャネルと出力チャネルが必ず同じになるため,BasicBlockのDownSampleを指定する必要がない.また,strideもデフォルト値のままである.

#[derive(Module, Debug)]
struct ResNetLayer<B: Backend> {
    blocks: [BasicBlock<B>; 2],
}

impl<B: Backend> ResNetLayer<B> {
    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.blocks[0].forward(x);

        self.blocks[1].forward(x)
    }
}

#[derive(Config, Debug)]
struct ResNetLayerConfig {
    in_planes: usize,
    out_planes: usize,
    stride: [usize; 2],
}

impl ResNetLayerConfig {
    fn init<B: Backend>(&self, device: &B::Device) -> ResNetLayer<B> {
        let downsample = if self.stride != [1, 1] || self.in_planes != self.out_planes {
            Some(DownSampleConfig::new(
                self.in_planes,
                self.out_planes,
                self.stride,
            ))
        } else {
            None
        };
        ResNetLayer {
            blocks: [
                BasicBlockConfig::new(self.in_planes, self.out_planes)
                    .with_stride(self.stride)
                    .with_downsample(downsample)
                    .init(device),
                BasicBlockConfig::new(self.out_planes, self.out_planes).init(device),
            ],
        }
    }
}

全体の実装

#[derive(Module, Debug)]
pub struct ResNet18<B: Backend> {
    conv1: Conv2d<B>,
    bn1: BatchNorm<B, 2>,
    activation: Relu,
    maxpool: MaxPool2d,
    layer1: ResNetLayer<B>,
    layer2: ResNetLayer<B>,
    layer3: ResNetLayer<B>,
    layer4: ResNetLayer<B>,
    avgpool: AdaptiveAvgPool2d,
    fc: Linear<B>,
}

impl<B: Backend> ResNet18<B> {
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
        let x = self.conv1.forward(x);
        let x = self.bn1.forward(x);
        let x = self.activation.forward(x);
        let x = self.maxpool.forward(x);

        let x = self.layer1.forward(x);
        let x = self.layer2.forward(x);
        let x = self.layer3.forward(x);
        let x = self.layer4.forward(x);

        let x = self.avgpool.forward(x);
        // let [batch_size, channel, height, width] = x.dims();
        // let x = x.reshape([batch_size, channel * height * width]);
        let x = x.flatten(1, 3);

        self.fc.forward(x)
    }

    pub fn forward_classification(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
        let targets = batch.targets;
        let [batch_size, height, width] = batch.images.dims();
        // チャネル数1を加えて,4次元に変換
        // MNISTはグレースケールなので,チャネル数1になる.RGB画像であれば3になる
        let image = batch
            .images
            .reshape([batch_size, 1, height, width])
            .detach();

        let output = self.forward(image);
        let loss = CrossEntropyLossConfig::new()
            .init(&output.device())
            .forward(output.clone(), targets.clone());

        ClassificationOutput::new(loss, output, targets)
    }
}

#[derive(Config, Debug)]
pub struct ResNet18Config {
    num_classes: usize,
    input_channel: usize,
    #[config(default = 1)]
    block_expansion: usize,
}

impl ResNet18Config {
    pub fn init<B: Backend>(self, device: &B::Device) -> ResNet18<B> {
        ResNet18 {
            //channels[ ]は,多分入力チャネル,出力チャネル
            conv1: Conv2dConfig::new([self.input_channel, 64], [7, 7])
                .with_stride([2, 2])
                .with_padding(PaddingConfig2d::Explicit(3, 3))
                .with_bias(false)
                .with_initializer(nn::Initializer::KaimingNormal {
                    gain: (2.0_f64).sqrt(),
                    fan_out_only: true,
                })
                .init(device),
            bn1: BatchNormConfig::new(64).init(device),
            activation: Relu::new(),
            maxpool: MaxPool2dConfig::new([3, 3])
                .with_strides([2, 2])
                .with_padding(PaddingConfig2d::Explicit(1, 1))
                .init(),

            layer1: ResNetLayerConfig::new(64, 64, [1, 1]).init(device),
            layer2: ResNetLayerConfig::new(64, 128, [2, 2]).init(device),
            layer3: ResNetLayerConfig::new(128, 256, [2, 2]).init(device),
            layer4: ResNetLayerConfig::new(256, 512, [2, 2]).init(device),
            avgpool: AdaptiveAvgPool2dConfig::new([1, 1]).init(),
            fc: LinearConfig::new(512 * self.block_expansion, self.num_classes).init(device),
        }
    }
}

最後に,学習を行うためのトレイトを実装する.

impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for ResNet18<B> {
    fn step(&self, item: MnistBatch<B>) -> burn::train::TrainOutput<ClassificationOutput<B>> {
        let item = self.forward_classification(item);

        TrainOutput::new(self, item.loss.backward(), item)
    }
}

impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for ResNet18<B> {
    fn step(&self, item: MnistBatch<B>) -> ClassificationOutput<B> {
        self.forward_classification(item)
    }
}

実行結果

リポジトリの通りに実行すると,次のメトリクスが得られる.

| Split | Metric          | Min.     | Epoch    | Max.     | Epoch    |
|-------|-----------------|----------|----------|----------|----------|
| Train | Accuracy        | 95.638   | 1        | 99.213   | 5        |
| Train | CPU Memory      | 13.915   | 1        | 14.614   | 2        |
| Train | CPU Temperature | NaN      | 1        | NaN      | 5        |
| Train | CPU Usage       | 34.847   | 4        | 36.586   | 2        |
| Train | Loss            | 0.025    | 5        | 0.140    | 1        |
| Valid | Accuracy        | 97.530   | 1        | 98.980   | 4        |
| Valid | CPU Memory      | 14.421   | 5        | 14.639   | 1        |
| Valid | CPU Temperature | NaN      | 1        | NaN      | 5        |
| Valid | CPU Usage       | 56.862   | 4        | 63.358   | 1        |
| Valid | Loss            | 0.034    | 4        | 0.076    | 1        |

参考文献

  1. https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
  2. https://qiita.com/TaiseiYamana/items/b3e97da112d912c66563
  3. https://arxiv.org/abs/1512.03385
  4. https://qiita.com/teacat/items/d6b24fb5353872f6b3a3
  5. https://burn.dev/books/burn/overview.html
  6. https://burn.dev/docs/burn/index.html

Discussion