🦠

Mosic Augmentationを使ったSemantic Segmentationモデル実装

2023/12/03に公開

初学者向きの日本語記事で細かい設計に踏み込んだ解説が少ないので書きます。
タスクに関わる部分は可能な限りtorchtorchvisionで完結しシンプルになるように心がけました。この記事で説明することは主に次の内容です。

  • セグメンテーションタスクの概観
  • データ作成
  • データセットクラスの作成
  • セグメンテーションモデルの作成
  • 損失関数と評価指標の作成
  • 学習ループ作成
  • 学習と検証

作業環境

私はpyenvなどよりDocker上に実験環境を作る方が好みなので、以下のDockerfileを使ってコンテナで作業しているが、おそらくGoogle Colabなどを用いても同じことはできると思う。
Colabではpip install ...の部分をipynb上にコピペすればライブラリを用意できる。

ARG PYTORCH="2.1.1"
ARG CUDA="12.1"
ARG CUDNN="8"

FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel

ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" \
    TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
    CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
    FORCE_CUDA="1" \
    DEBIAN_FRONTEND=noninteractive

# Install the required packages
RUN apt-get update
RUN apt-get install -y ffmpeg git ninja-build libsm6 libxext6 libglib2.0-0 libsm6 libxrender-dev libxext6 \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

RUN pip install lightning matplotlib seaborn onnx onnxruntime onnxoptimizer scikit-learn \ 
    && pip install timm torchinfo torchmetrics transformers polars
CMD ["/bin/bash"]

セマンティックセグメンテーション

最も基本的なタスクとして画像分類が有名だが、分類モデルは入力画像\texttt{img} \in [0,255]^{3 × H × W}に対して分類クラスのベクトル\texttt{out} \in [0,1]^{\texttt{cls}}を出力し、セマンティックセグメンテーションモデルは入力画像\texttt{img} \in [0,255]^{3 × H × W}に対して分類クラス数のチャンネルを持つ配列(スーパー画像)\texttt{out} \in [0,1]^{\texttt{cls} × H × W}が出力される。すなわちピクセル単位で分類するのがセマンティックセグメンテーションだという認識で良さそう。

今後混乱がないように書くと、PyTorchの画像系NNモデルは一般的に入力次元が[Batch(B), Channel(C), Height(H), Width(W)]となっている点に注意。Batchは学習時のミニバッチの枚数。

これから実行するモデルは、画像imgと訓練データ(マスク)mskを受け取り、教師あり学習を行うことで、imgからマスクを推定する。まずはデータセットを作っていく。

データセットの作成

今回は細胞の輪郭から細胞の画素と種類を推定するタスクを実行する。
データセットは以下のLiveCell 2021のCOCO Datasetからマスクを取り出して、画像に整形して作ることにするが、形式さえ同じなら何を使ってもいい。
A172, BT-474, BV-2, Huh7, MCF7, SH-SY5Y, SkBr3, SK-OV-3の8クラスの細胞があり、それぞれ1つの画像の中に1クラスが写っている。

[1] Christoffer Edlund, Timothy R. Jackson, Nabeel Khalid, Nicola Bevan, Timothy Dale, Andreas Dengel, Sheraz Ahmed, Johan Trygg & Rickard Sjögren
"LIVECell—A large-scale dataset for label-free live cell segmentation"
https://www.nature.com/articles/s41592-021-01249-6

https://sartorius-research.github.io/LIVECell/

[1] 細胞画像の特徴空間上での分布

[1] マスクの例

このデータセットを整形したのがこちら。
整形内容

  • 全画像を256×256で右上、左上、右下、左下部分をクロップ
  • COCOアノテーションファイルのポリゴン頂点をグレースケール画像化

整形済み画像データとマスクデータ

左上からA172, BT-474, BV-2, Huh7、下段にMCF7, SH-SY5Y, SkBr3, SK-OV-3となっている。素人には判別がつかない。
マスクも同じクラスの並びにしている。それぞれのクラスのマスクは輝度値が10ずつ異なる。これによりプログラム部分で[0,1]^{\texttt{cls} × H × W}の形に分けやすくしている。

自作データセットを作る場合はこのようにカラー画像と対応するグレースケールマスク画像を作ることになる。今回は同一データ内に異なるクラスのマスクが無い(後で工夫して混合する)が、一般的には混ざっていても問題ない。グレースケールでなくても、カラーパレットのRGB値を持っているなら原理的には同様の操作が可能なので、クラス数が255より多い場合などはカラーを使う手もある。

参考までに、他のデータセットとしてCOCO Stuffも併載する。

[2] Holger Caesar, Jasper Uijlings, Vittorio Ferrari,
"COCO-Stuff: Thing and Stuff Classes in Context"
https://arxiv.org/abs/1612.03716

https://github.com/nightrome/cocostuff

COCO Stuffの画像とマスク

[2] COCO Stuffのクラス情報など

データセットを整形したら、次のようにデータを配置する。今回、trainとvalは両方trainに混ぜてしまい、testのみ隔離する。また、画像と対応するマスク画像は同じ名前にすると便利。

experiment/
 └─ dataset/
...  ├─ images_train/
         ├─ train_cls1_0000000_0.jpg
	 └─ train_cls1_0000000_1.jpg ...
     ├─ images_test/
         ├─ test_cls1_0000000_0.jpg
	 └─ test_cls1_0000000_1.jpg ...
     ├─ masks_train/
         ├─ train_cls1_0000000_0.png
	 └─ train_cls1_0000000_1.png ...
     └─ images_test/
         ├─ test_cls1_0000000_0.png
	 └─ test_cls1_0000000_1.png ...

これで一旦準備ができたので、次はデータをPythonへ読み込むプログラム部分を作っていく。

データの読み込み

コードを書く前に、この記事で使用するライブラリ一覧を示す。また登場次第適宜紹介する。

  • import glob: ファイル名を部分一致で検索しリストにして持ってくる
  • import multiprocessing: おまじない
  • import warnings: おまじない
  • import matplotlib.pyplot as plt: テンソルを画像化するとき使う
  • import torch: いつもの
  • from torch import nn: モデル構築に使う
  • from torch.nn import functional as F: テンソルに関数を適用するとき使う
  • import torchvision: 画像読み込みなどでつかう
  • from torchvision.transforms import v2 as transforms: 画像変換
  • from torchvision.transforms.v2 import functional as transformsF: 画像変換関数
  • import lightning as pl: pytorchの最適化ループを簡単にしてくれる
  • import timm: 事前学習済み画像分類NNの動物園

torch.utils.data.Dataset部分

データの対をインデックスアクセスで取ってこれるクラスを作る。
記事ではクラス数などを直書きしているが、実際は引数などに持たせるように書き、クラスなどは他のファイルから呼び出している。

train.py
class LivecellMosicDataset(torch.utils.data.Dataset):
    def __init__(self, images, masks, is_train=False):
        # input:
	#    images: 画像のpathのリスト
	#    masks: マスク画像のpathのリスト
	#    is_train: bool trainデータの場合True
        self.images = images
        self.masks = masks
	
	# trainデータならモザイク組み込みするインスタンスを作る
	# testデータなら普通に読み込むインスタンスを作る
        self.load = MosicImageLoader(images, masks) if is_train == True else ImageLoader(images, masks)
	
	# 画像とマスク画像に対して画像処理(data augmentation)を行う
        self.augmentation = MyResizeFlipCrop(size=IMGSIZE[0], normrize_img=True)
        return None
    
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # 画像とマスクを読み込む
        img, msk = self.load(idx)
	# 画像とマスクに対して画像処理を行う
        img, msk = self.augmentation(img, msk)

        # マスク画像は輝度値10*NUM_CLASSなので、1刻みに変える
        msk = transforms.functional.rgb_to_grayscale(msk)//10
	# マスク画像[1,H,W]を要素{0, 1}の8クラスのマスクを表現するテンソル[8,H,W]へ変換する
        msk = maskonehot(msk, 8)
        
        return img.to(torch.float), msk.to(torch.float)

この中に出てくるImageLoader, MosicImageLoader, MyResizeFlipCrop, maskonehotはこれから作っていく。
概観としては、次のように動作して欲しいというものである。

dataset = LivecellMosicDataset([dataset/a.jpg ...], [dataset/a.png ...])
len(dataset)  # = 5678
img, msk = dataset[123]
img  # tensor [3,256,256]
msk  # tensor [8,256,256]

ImageLoaderの作成

クラスImageLoaderは単純に画像とマスクのpathを受け取って画像とマスクを返す。

class ImageLoader():
    def __init__(self, pathes_img, pathes_msk):
        self.pathes_img = pathes_img
        self.pathes_msk = pathes_msk
        return None
    
    def __call__(self, idx):
        img = torchvision.io.read_image(self.pathes_img[idx]).to(torch.float)
        msk = torchvision.io.read_image(self.pathes_msk[idx]).to(torch.float)
        return img, msk

MosicImageLoader部分

イレギュラーなので読み飛ばしても問題ない。
今回例で使うデータセットは1枚の画像に1クラスのマスクしか無いため、特徴の増幅のためMosaic data augmentationを行った。

[3] Alexey Bochkovskiy, Chien-Yao Wang, Hong-Yuan Mark Liao,
"YOLOv4: Optimal Speed and Accuracy of Object Detectio"
https://arxiv.org/abs/2004.10934

[3] YOLOv4のデータ拡張

アクセスした画像のインデックスから4つの画像を決定的に持ってきて、結合したのちランダムにクロップすることで実装する。従って、出力画像はオリジナルの2倍の大きさになる点には注意。

class MosicImageLoader():
    def __init__(self, pathes_img, pathes_msk):
        # input:
        #     pathes_img: 画像のpathのリスト
	#     pathes_msk: マスク画像のpathのリスト
        self.pathes_img = pathes_img
        self.pathes_msk = pathes_msk
        self.l = len(pathes_img)
	# アクセスした画像以外のpathを決定する乱数を決定
        self.rand = torch.randint(high=len(pathes_img), size=(3,))
        return None

    def mosicconcat(self, img_list):
        # input List[tensor[C,H,W] × 4]
        img1 = torch.cat([img_list[0], img_list[1]], dim=2)
        img2 = torch.cat([img_list[2], img_list[3]], dim=2)
        img = torch.cat([img1, img2], dim=1)
        return img
    
    def __call__(self, idx):
        img = self.mosicconcat([
            torchvision.io.read_image(self.pathes_img[idx]),
            torchvision.io.read_image(self.pathes_img[(idx+self.rand[0])%self.l]),
            torchvision.io.read_image(self.pathes_img[(idx+self.rand[1])%self.l]),
            torchvision.io.read_image(self.pathes_img[(idx+self.rand[2])%self.l]),
            ]).to(torch.float)
        msk = self.mosicconcat([
            torchvision.io.read_image(self.pathes_msk[idx]),
            torchvision.io.read_image(self.pathes_msk[(idx+self.rand[0])%self.l]),
            torchvision.io.read_image(self.pathes_msk[(idx+self.rand[1])%self.l]),
            torchvision.io.read_image(self.pathes_msk[(idx+self.rand[2])%self.l]),
            ]).to(torch.float)
        return img, msk

MyResizeFlipCrop部分

torchvision.transformsの変形はマスク画像に同じ変形を適用できない。(できるらしいがうまく行かなかった)
そのため、入力された画像とマスクに同じ変形を行うクラスを作成する。

class MyResizeFlipCrop():
    def __init__(self, size=256, clopscale=0.5, hflip=0.5, vflip=0.5, normrize_img=True):
        # input:
	#    size: 変換後出力される画像のサイズ
	#    clopscale: 元画像に対してどの割合でクロップするか
	#    hflip, vflip: 画像の反転確率
	#    normrize_img: 画像を標準化するかどうか(基本的にはする)
        self.size = size
        self.scale = clopscale
        self.hfp = hflip
        self.vfp = vflip
        self.normrize = Transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) if normrize_img == True else nn.Identity()
        return None
    
    def __call__(self, img, msk):
        # input [3,H,W], [C,H,W]
        scale = self.scale
        y, x = torch.randint(high=int(self.size/scale-self.size-1), size=(2,))
        img = transformsF.resize(
            img, 
            int(self.size/scale), 
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC)[:,x:x+self.size, y:y+self.size]
        img = self.normrize(img)
        msk = transformsF.resize(
            msk, 
            int(self.size/scale),
            interpolation=torchvision.transforms.InterpolationMode.NEAREST_EXACT)[:,x:x+self.size, y:y+self.size]
        if torch.rand(1) < self.hfp:
            img = transformsF.horizontal_flip(img)
            msk = transformsF.horizontal_flip(msk)
        if torch.rand(1) < self.vfp:
            img = transformsF.vertical_flip(img)
            msk = transformsF.vertical_flip(msk)
        return img, msk

maskonehot部分

マスク画像を、モデル出力と比較するためのテンソルに変換する。
この記事ではグレースケール画像のマスク画像(成分は0~255)を8クラスのone-hotなテンソルにする。
以下の記事は非常に具体的な図が載っていて理解の助けになると思う。

https://www.jeremyjordan.me/semantic-segmentation/

def maskonehot(msk, n_cls: int):
    # input [1,H,W] -> output [CLS,H,W]
    # f: \mathbb{N}^{1×H×W} -> [0,1]^{CLS×H×W}
    return torch.concat([torch.where(msk == c+1, 1, 0) for c in range(n_cls)], dim=0)
   
mskimg = torch.zeros(1,256,256)
msktensor = maskonehot(mskimg, 8)
msktensor  # tensor [8,256,256]

以上によりデータセットの準備は完了した。次は学習モデルについて見ていく。

セグメンテーションモデルの作成

普通は色々なライブラリから事前学習済みモデルを持ってくるのが効率がいいが、この記事では学習のためプログラム本体を書いて行く。
方針としては、timmからweightごとEfficientNetv2もモデルをビルドし、これをバックボーンとして3種類の解像度の特徴マップを引っこ抜いてきて、LightHamの分類器に入力する形にする。
つまりEfficientSegNextを作る。

Mingxing Tan, Quoc V. Le,
"EfficientNetV2: Smaller Models and Faster Training"
https://arxiv.org/abs/2104.00298

Zhengyang Geng, Meng-Hao Guo, Hongxu Chen, Xia Li, Ke Wei, Zhouchen Lin,
"Is Attention Better Than Matrix Decomposition?"
https://arxiv.org/abs/2109.04553

https://github.com/Gsunshine/Enjoy-Hamburger

[4] Meng-Hao Guo, Cheng-Ze Lu, Qibin Hou, Zhengning Liu, Ming-Ming Cheng, Shi-Min Hu
"SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation"
https://arxiv.org/abs/2209.08575

SegNextの解説は以前記事にしたので詳しくは以下に譲る

https://zenn.dev/inaturam/articles/d2c0363bd7c1aa

以下の図の(a)と(c)を融合した形をつくることで、EfficientNetv2の速度性能と行列分解デコーダの速度性能が両方備わり最強に見える。

[4] エンコーダデコーダの比較

モデルの詳細な計算処理などは深く触れないが、モジュール単位で何がやり取りされているかにのみ注目する。

EfficientNetv2 Backbone

class EfficientNetBackbone(nn.Module):
    def __init__(self):
        super().__init__()
	# ここでEfficientNetv2をビルド
        model = timm.create_model(
            "tf_efficientnetv2_s.in21k_ft_in1k",
            pretrained=True,
            )
	# 以下でモデルを分解
        self.stem = nn.Sequential(
            model.conv_stem,
            model.bn1
            )
        self.block0 = model.blocks[0]
        self.block1 = model.blocks[1]
        self.block2 = model.blocks[2]
        self.block3 = model.blocks[3]
        self.block4 = model.blocks[4]
	# 256×256入力ではblock5に殆ど位置情報が残らないため、断腸の思いで捨てる
        # self.block5 = model.blocks[5]

        # channel数をデコーダと合わせるためのMLP層
        self.mlp1 = nn.Conv2d( 48,  64, 1)
        self.mlp2 = nn.Conv2d( 64, 128, 1)
        self.mlp3 = nn.Conv2d(160, 256, 1)
        return None
    
    def forward(self, x):
        # input:
	#    x: 画像入力 [B,3,H,W]
        x = self.stem(x)
        x = self.block0(x)  # output [B, 24, H/2, W/2]
        x = self.block1(x)  # output [B, 48, H/4, W/4]
        x1 = self.mlp1(x)
        x = self.block2(x)  # output [B, 64, H/8, W/8]
        x2 = self.mlp2(x)
        x = self.block3(x)  # output [B,128,H/16,W/16]
        x = self.block4(x)  # output [B,160,H/16,W/16]
        x3 = self.mlp3(x)
        # x = self.block5(x)  # (output [B,256,H/32,W/32])

        return x1, x2, x3  # HW/4, HW/8, HW/16 

これにより解像度がオリジナルの1/4, 1/8, 1/16の特徴マップを出力できるので、これをデコーダへ入力する。

LightHam Decoder

入力された3つの特徴マップの大きさを揃えて結合し、これをMLPと行列分解で特徴を分離しクラス数のテンソルに変換する。

\texttt{lightham}: \mathbb{R}^{B×64×H/4×W/4} \mathbb{R}^{B×128×H/8×W/8} \mathbb{R}^{B×256×H/16×W/16} \longmapsto \mathbb{R}^{B×8×H/4×W/4}

最後に、空間方向を4倍してオリジナルのマスクテンソルの大きさと揃えて、Sigmoid関数で[0,1]に変換すれば出力形が完成する。

行列分解部分+MLP部分
class NMF2D(nn.Module):
    def __init__(self, device="cuda"):
        super().__init__()
        self.spatial = True
        self.S = 1
        self.D = 512
        self.R = 64
        self.train_steps = 6
        self.eval_steps = 7
        self.inv_t = 1  # default 100
        self.eta = 0.9
        self.rand_init = True

        # Use for device verification such as `self.tensordevice.device == torch.device('cpu')`
        self.tensordevice = torch.device(device)  # nn.Parameter(torch.empty(0))
        return None

    def _build_bases(self, B, S, D, R):
        # _MatrixDecomposition2DBase default empty
        if self.tensordevice == torch.device('cuda'):
            bases = torch.rand((B * S, D, R)).cuda()
        else:
            print("cpu mode")
            bases = torch.rand((B * S, D, R))
        bases = F.normalize(bases, dim=1)
        return bases

    def local_step(self, x, bases, coef):
        if self.tensordevice == torch.device('cuda'):
            bases = bases.cuda()
        else:
            print("cpu mode")
        # _MatrixDecomposition2DBase default empty
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative Update
        coef = coef * numerator / (denominator + 1e-6)

        # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
        numerator = torch.bmm(x, coef)
        # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
        denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
        # Multiplicative Update
        bases = bases * numerator / (denominator + 1e-6)

        return bases, coef

    def local_inference(self, x, bases):
        if self.tensordevice == torch.device('cuda'):
            bases = bases.cuda()
        else:
            print("cpu mode")
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        coef = torch.bmm(x.transpose(1, 2), bases)
        coef = F.softmax(self.inv_t * coef, dim=-1)

        steps = self.train_steps if self.training else self.eval_steps
        for _ in range(steps):
            bases, coef = self.local_step(x, bases, coef)

        return bases, coef

    def compute_coef(self, x, bases, coef):  
        if self.tensordevice == torch.device('cuda'):
            bases = bases.cuda()
        else:
            print("cpu mode")
        # _MatrixDecomposition2DBase default empty

        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # multiplication update
        coef = coef * numerator / (denominator + 1e-6)

        return coef

    def forward(self, x, return_bases=False):
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B * S, D, N)
        if self.spatial:
            D = C // self.S
            N = H * W
            x = x.view(B * self.S, D, N)
        else:
            D = H * W
            N = C // self.S
            x = x.view(B * self.S, N, D).transpose(1, 2)
        if not self.rand_init and not hasattr(self, 'bases'):
            bases = self._build_bases(1, self.S, D, self.R)
            self.register_buffer('bases', bases)
        # (S, D, R) -> (B * S, D, R)
        if self.rand_init:
            bases = self._build_bases(B, self.S, D, self.R)
        else:
            bases = self.bases.repeat(B, 1, 1)
        bases, coef = self.local_inference(x, bases)
        # (B * S, N, R)
        coef = self.compute_coef(x, bases, coef)
        # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
        x = torch.bmm(bases, coef.transpose(1, 2))
        # (B * S, D, N) -> (B, C, H, W)
        if self.spatial:
            x = x.view(B, C, H, W)
        else:
            x = x.transpose(1, 2).view(B, C, H, W)
        # (B * H, D, R) -> (B, H, N, D)
        bases = bases.view(B, self.S, D, self.R)

        return x


class Hamburger(nn.Module):
    def __init__(self, ham_channels, device="cuda"):
        super().__init__()

        self.hamburger = nn.Sequential(
            nn.Conv2d(ham_channels, ham_channels, 1),
            nn.ReLU(inplace=True),
            NMF2D(device=device),
            nn.Conv2d(ham_channels, ham_channels, 1)
        )
        self.relu = nn.ReLU(inplace=True)
        return None

    def forward(self, x):
        ham = self.hamburger(x)
        ham = self.relu(x + ham)
        return ham


class LightHamHead(nn.Module):
    def __init__(self, in_channels=[128, 320, 512], ham_channels=512, channels=512, num_classes=150, device="cuda"):
        super().__init__()
        input_chs = sum(in_channels)
        output_chs = channels

        self.upsample_x2 = nn.Upsample(scale_factor=2, mode='bicubic')
        self.upsample_x4 = nn.Upsample(scale_factor=4, mode='bicubic')
        self.squeeze = nn.Conv2d(input_chs, ham_channels, 1)
        self.hamburger = Hamburger(ham_channels, device=device)
        self.align = nn.Conv2d(ham_channels, output_chs, 1)
        self.cls_seg = nn.Conv2d(output_chs, num_classes, 1)
        return None

    def forward(self, p1, p2, p3):
        inputs = torch.cat([p1, self.upsample_x2(p2), self.upsample_x4(p3)], dim=1)
        x = self.squeeze(inputs)
        x = self.hamburger(x)
        output = self.align(x)
        output = self.cls_seg(output)
        return output
class EfficientSegnext(nn.Module):
    def __init__(self, num_class=8, device="cuda"):
        super().__init__()
	# バックボーンとデコーダを結合する
        self.backbone = EfficientNetBackbone()
        self.head = LightHamHead(
            in_channels=[64,128,256], 
            ham_channels=256, 
            channels=256, 
            num_classes=num_class,
            device=device
            )
	# テンソルのサイズをオリジナルの大きさに戻す
        self.upsample_x4 = nn.Upsample(scale_factor=4, mode='bicubic')
	# 値を[0,1]にする
        self.sigmoid = nn.Sigmoid()
        return None

    def forward(self, x):
        p1, p2, p3 = self.backbone(x)
        out = self.head(p1, p2, p3)
        out = self.upsample_x4(out)
        out = self.sigmoid(out)
        return out

これでモデルは完成である。バックボーンのEfficientNetv2は学習済みの重みを持っているので、ある程度セグメンテーションの学習時間を短縮できる。
このモデルは次のように使える。

model = EfficientSegnext(device="cpu")
model  # モデルの全体像は下の折りたたみ欄に記載
out = model(torch.zeros(1,3,256,256))
out.shape  # tensor [1,8,256,256]
モデルの全体像
EfficientSegnext(
  (backbone): EfficientNetBackbone(
    (stem): Sequential(
      (0): Conv2dSame(3, 24, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (1): BatchNormAct2d(
        24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
        (drop): Identity()
        (act): SiLU(inplace=True)
      )
    )
    (block0): Sequential(
      (0): ConvBnAct(
        (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (drop_path): Identity()
      )
      (1): ConvBnAct(
        (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (drop_path): Identity()
      )
    )
    (block1): Sequential(
      (0): EdgeResidual(
        (conv_exp): Conv2dSame(24, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (bn1): BatchNormAct2d(
          96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): Identity()
        (conv_pwl): Conv2d(96, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
          48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (1): EdgeResidual(
        (conv_exp): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): Identity()
        (conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
          48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (2): EdgeResidual(
        (conv_exp): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): Identity()
        (conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
          48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (3): EdgeResidual(
        (conv_exp): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): Identity()
        (conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
          48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
    )
    (block2): Sequential(
      (0): EdgeResidual(
        (conv_exp): Conv2dSame(48, 192, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (bn1): BatchNormAct2d(
          192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): Identity()
        (conv_pwl): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
          64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (1): EdgeResidual(
        (conv_exp): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): Identity()
        (conv_pwl): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
          64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (2): EdgeResidual(
        (conv_exp): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): Identity()
        (conv_pwl): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
          64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (3): EdgeResidual(
        (conv_exp): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): Identity()
        (conv_pwl): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
          64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
    )
    (block3): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2dSame(256, 256, kernel_size=(3, 3), stride=(2, 2), groups=256, bias=False)
        (bn2): BatchNormAct2d(
          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(256, 16, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (1): InvertedResidual(
        (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)
        (bn2): BatchNormAct2d(
          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (2): InvertedResidual(
        (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)
        (bn2): BatchNormAct2d(
          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (3): InvertedResidual(
        (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)
        (bn2): BatchNormAct2d(
          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (4): InvertedResidual(
        (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)
        (bn2): BatchNormAct2d(
          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (5): InvertedResidual(
        (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)
        (bn2): BatchNormAct2d(
          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
    )
    (block4): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(128, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          768, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)
        (bn2): BatchNormAct2d(
          768, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(768, 32, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(32, 768, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(768, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (1): InvertedResidual(
        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
        (bn2): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (2): InvertedResidual(
        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
        (bn2): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (3): InvertedResidual(
        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
        (bn2): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (4): InvertedResidual(
        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
        (bn2): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (5): InvertedResidual(
        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
        (bn2): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (6): InvertedResidual(
        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
        (bn2): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (7): InvertedResidual(
        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
        (bn2): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
      (8): InvertedResidual(
        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
        (bn2): BatchNormAct2d(
          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNormAct2d(
          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
    )
    (mlp1): Conv2d(48, 64, kernel_size=(1, 1), stride=(1, 1))
    (mlp2): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
    (mlp3): Conv2d(160, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (head): LightHamHead(
    (upsample_x2): Upsample(scale_factor=2.0, mode='bicubic')
    (upsample_x4): Upsample(scale_factor=4.0, mode='bicubic')
    (squeeze): Conv2d(448, 256, kernel_size=(1, 1), stride=(1, 1))
    (hamburger): Hamburger(
      (hamburger): Sequential(
        (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): NMF2D()
        (3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (relu): ReLU(inplace=True)
    )
    (align): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (cls_seg): Conv2d(256, 8, kernel_size=(1, 1), stride=(1, 1))
  )
  (upsample_x4): Upsample(scale_factor=4.0, mode='bicubic')
  (sigmoid): Sigmoid()
)

次はモデルの出力とマスクの誤差を計算し、最適化で必要な勾配を発生させる損失関数と評価指標を作っていく。

損失関数と評価指標

画像分類などではクロスエントロピーが有名だが、セマンティックセグメンテーションでは領域の重なりを表現するものをよく使う。
今回はIoULossとFocalLossをあわせたFocalIoULossを作っていく。

FocalIoULoss

IoUは(領域の積)/(領域の和)で計算される、推論結果と正解との重なりの割合である。つまり、完全に重なっていれば1になり、全く重なっていなければ0になる。
損失関数は誤差が小さいほど小さくならなくてはならなため、IoULossは1-\texttt{IoU}で計算される。
FocalLossはクロスエントロピーをよりコントラストした関数で、背景が大きくフィットしすぎないように関数の勾配を平坦にするγをかけることにより、小さい物体のフィットに有利な効果を持つ。

[5] A visual equation for Intersection over Union (Jaccard Index).
http://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/
Adrian Rosebrock

[6] Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár,
"Focal Loss for Dense Object Detection"
https://arxiv.org/abs/1708.02002

[5] IoUの視覚的な定義と [6] Focal Lossのカーブ

こちらの解説はFocalLossについて詳しく記されている。

https://qiita.com/agatan/items/53fe8d21f2147b0ac982

実装は以下のFocalLossとBCEDiceLossを参考にした。

https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch

# 評価指標
class IoUMetrics(nn.Module):
    def __init__(self):
        super().__init__()
        return None

    def forward(self, inputs, targets):
        smooth = 1  # 零除算回避用
	
	# テンソルを1列のベクトルに平坦化して計算すると楽
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
	# IoUの計算
        intersection = (inputs * targets).sum()
        union = (inputs + targets).sum() - intersection 
        return (intersection + smooth)/(union + smooth)


# 損失関数
class FocalIoULoss(nn.Module):
    def __init__(self, alpha=0.8, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        return None

    def forward(self, inputs, targets):
        smooth = 1    
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
	# IoUの計算
        intersection = (inputs * targets).sum()
        union = (inputs + targets).sum() - intersection 
        iou = (intersection + smooth)/(union + smooth)

	# FocalLossの計算
        bec = F.binary_cross_entropy(inputs, targets, reduction='mean')
        becexp = torch.exp(-bec)
        focal_loss = self.alpha * (1-becexp)**self.gamma * bec

        return 1 - iou + focal_loss

以上で損失と評価指標である。次はこれらを使って最適化を行う。

モデルの最適化

最適化を楽に書くツールとしてpytorch lightningを使う。
まずモデルへデータをミニバッチ化して渡すDataModuleを作り、次にモデルの推論と損失の計算を行うLightningModuleを作る。
最後にこれらをインスタンス化して最適化を開始するTrainerへ渡せば、自動的に学習と指標の記録が開始される。

DataModule部分

DataModuleはデータセットインスタンスを格納し、trainデータとtestデータを管理する。

train.py
class LivecellDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32, num_workers=4):
        super().__init__()  
	# input:
	#    batch_size: 学習時にモデルへ一括で渡すミニバッチの枚数
	#    num_workers: おまじない
	
	# 各データの場所をglobで表現
	# "*"の部分は任意の文字列が対応する
        self.images_train = "dataset/images_train/*.jpg"
        self.images_test = "dataset/images_test/*.jpg"
        self.masks_train = "dataset/masks_train/*.png"
        self.masks_test = "dataset/masks_test/*.png"
        self.batch_size = batch_size
        self.num_workers = num_workers
        return None

    def prepare_data(self):
        return None

    def setup(self, stage):
        # globでデータセットのpathのリストを一括で持ってくる
	# その後ソートして画像とマスクの順番が対応することを保証する
        images_train = glob.glob(self.images_train)
        images_train.sort()
        masks_train = glob.glob(self.masks_train)
        masks_train.sort()
        images_test = glob.glob(self.images_train)
        images_test.sort()
        masks_test = glob.glob(self.masks_train)
        masks_test.sort()
        
	# trainとtestのデータセットインスタンスを作成
        self.train_dataset = LivecellMosicDataset(images_train, masks_train, is_train=True)
        self.test_dataset = LivecellMosicDataset(images_test, masks_test, is_train=True)
        return None
	
    # 以下でtrainとtestのデータセットをデータローダへ渡す
    # データローダはデータセットから画像とマスクを呼び出し、ミニバッチ化して返してくれる
    # データセット側ではデータは[3,256,256][8,256,256]のサイズで返されるが
    # データローダはbatchsizeで纏めて[32,3,256,256][32,8,256,256]で返してくれる
    
    def train_dataloader(self):
        train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers
        )
        return train_dataloader
    
    def val_dataloader(self):
        test_dataloader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        return test_dataloader

少し動かしてみると挙動を理解しやすい。

datamodule = LivecellDataModule(batch_size=10)
datamodule.setup(None)
dataloader = datamodule.train_dataloader()

for imgs, msks in dataloader:
    break
imgs  # tensor [10,3,256,256]
msks  # tensor [10,8,256,256]

モザイク拡張を受けた画像は次のようになっている。


LightningModule部分

LightningModuleはモデルと最適化設定を管理する。

train.py
class SegmentationModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
	# モデル、損失関数、評価指標をインスタンス化
        self.model = EfficientSegnext(device="cuda")
        self.loss = FocalIoULoss()
        self.metrics = IoUMetrics()
        return None
    
    def forward(self, x):
        # LightningModule本体を呼んだときの推論ロジック
        y = self.model(x)
        return y

    def training_step(self, batch, batch_idx):
        # 学習時のロジック
        img, msk = batch  # ミニバッチを取り出す
        out = self.forward(img)  # 推論
        loss = self.loss(out, msk)  # 損失計算

        self.log("train_loss", loss, logger=True, prog_bar=True)  # 経過を記録
        return {"loss": loss}  # 損失を返すことでtrainerが勝手に逆伝播してくれる
    
    def validation_step(self, batch, batch_idx):
        # 検証時のロジック
        img, msk = batch
        out = self.forward(img)
        m = self.metrics(out, msk)

        self.log("IoU", m, logger=True, on_epoch=True, prog_bar=True) 
        return {"metrics": m}
                
    def configure_optimizers(self):
        # 最適化設定
	maxepoch = 16
	# 正則化がついていて収束が高速安定なAdamWを使う
        optimizer = torch.optim.AdamW(  
            self.parameters(),  # モデルのパラメータがLightningModuleに格納されるのでこれを使う
            lr=0.001,  # 初期学習率
            weight_decay=0.01  # 正則化の大きさ
	    )
	# 学習率の減衰はコサインカーブで下げる
	# これにより学習率やepoch数で花瓶に過学習が起きることが防げる
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 
            optimizer,  # 上で定義した最適化手法インスタンス
            maxepoch,  # 最大epoch数 
            eta_min=0.00001  # 減衰の最小値
	    )
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

実行部分

train.py
def main():
    datamodule = LivecellDataModule(batch_size=BATCHSIZE) 
    model = SegmentationModel.load_from_checkpoint("model_efficientsegnext_16epoch.ckpt")

    pllogger_csv = pl.pytorch.loggers.CSVLogger(
        "./logs/",
	name="EfficientSegNext"  # ログを保存するディレクトリの名前 
	)
    trainer = pl.Trainer(
        logger=pllogger_csv,  # CSVファイルで最適化の記録を保存
        enable_checkpointing=True,  # 学習した重みを保存する
        check_val_every_n_epoch=1,  # 検証を行うインターバル
        accelerator="gpu",
        devices=1,
        max_epochs=16,
        )
	
    # 学習ループ実行
    trainer.fit( 
        model, 
        datamodule=datamodule
        )

    # 結果の検証 -> list[dict[str][float]]
    result = trainer.validate(
        model,
        datamodule=datamodule
        )
    print(result)

    return None


if __name__ == "__main__":
    multiprocessing.freeze_support()
    main()

これによりモデルの最適化が周る。実際の結果を見てみると次のようになる。

root@d09aeb1b8aef:/workspace# python train_mosic.py
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | model   | EfficientSegnext | 5.7 M 
1 | loss    | FocalIoULoss     | 0     
2 | metrics | IoUMetrics       | 0     
---------------------------------------------
5.7 M     Trainable params
0         Non-trainable params
5.7 M     Total params
22.610    Total estimated model params size (MB)
Epoch 3: 100%|█████████████████████████████████████████| 470/470 [07:37<00:00,  1.03it/s, v_num=9, train_loss=0.277, IoU=0.719]`Trainer.fit` stopped: `max_epochs=4` reached.                                                                                 
Epoch 3: 100%|█████████████████████████████████████████| 470/470 [07:37<00:00,  1.03it/s, v_num=9, train_loss=0.277, IoU=0.719]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Validation DataLoader 0: 100%|███████████████████████████████████████████████████████████████| 470/470 [02:35<00:00,  3.02it/s]
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
           IoU              0.7335798740386963
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

IoU = 0.719はちょっと低いが、16epochで軽量なモデルを使っているからだろうか...
実際に結果を可視化してみる。

vis.py
def mask_vis(img, msk, infer, n_cls, fname=""):
    plt.gca().axis("off")
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    plt.imshow(torch.sum(img[0], dim=0), cmap="pink")
    plt.savefig(f"visualize_{fname}_img.png")
    for c in range(n_cls):
        plt.imshow(msk[0,c], cmap="hot")
        plt.savefig(f"visualize_{fname}_msk_{c+1}.png")
        plt.imshow(infer[0,c], cmap="hot")
        plt.savefig(f"visualize_{fname}_infer_{c+1}.png")
    plt.clf()
    plt.close()
    return None


def vis():
    # モデルの読み込み
    model = SegmentationModel.load_from_checkpoint("model_efficientsegnext_mosicfintuning_16epoch.ckpt")
    model = model.model.cuda()
    model.eval()
    
    # データモジュールの準備
    datamodule = LivecellDataModule(batch_size=1)
    datamodule.setup(None)
    dataloader = datamodule.val_dataloader()
    
    for img, msk in dataloader:
        break
    infer = model(img.cuda())
    infer = infer.detach().cpu()

    # 画像、マスク、推論結果の可視化
    mask_vis(img, msk, infer, 8)

    return None
    

if __name__ == "__main__":
    multiprocessing.freeze_support()
    vis()

上2行が正解マスク、下2行が推論結果

セグメンテーションとしては少しレベルが低いが、クラス分類的にはきれいに行っていると見ていい気がする。

モデルは非常に高速で、ResNet50-DeepLabv3pulsの1/2ほどの時間で学習できた。今回はEfficientNetv2-Sを一部削除して使っているため超高速だが、やはり特徴抽出の部分で厳しかったのかもしれない。余裕があったらEfficientNet以外にもConvNextやHRViTなどのバックボーンを使って試してみるのもいいかもしれない。

ちなみに、モザイク拡張を使わずに学習した場合はIoU = 0.836でそこそこ正確だったが、モザイク拡張を行ったテストデータに対しての精度は全くだめで、分類すらできていなかった。
これはおそらくデコード時のMLP層で大域的に特徴が有るか無いかに影響を受けて識別するようになってしまった説が有力だと思う。やはり画像に多くの特徴を持たせなければ外挿時の精度が落ちるという結論で正しい気がする。

おまけ: モデルのONNX化

モデルを外部のデバイスで使うときなどはONNXでやり取りするのが良い。

def saveonnx():
    model = SegmentationModel.load_from_checkpoint("model_effficientsegnext.ckpt") 
    model = model.model.cpu().eval()

    model(torch.ones(1,3,256,256))

    dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)  
    torch.onnx.export(
        model,
        dummy_input,
        "efficientsegnext.onnx",  
        export_params=True,  # store the trained parameter weights inside the model file 
        opset_version=16,    # the ONNX version to export the model to 
        do_constant_folding=True,  # whether to execute constant folding for optimization 
        input_names=['input_image'],   # the model's input names 
        output_names=['output_mask'], # the model's output names 
        dynamic_axes={
            'input_image' : {0 : 'batch_size'},
            'output_mask' : {0 : 'batch_size'},
        }) 

    return None

これでモデルがONNXの計算グラフに変換され、構造なども見れるようになる。
今回作成したEfficientSegnextは行列分解に漸近アルゴリズムを使っているため、図に載せなかった部分は非常に細長いモデル構造になっていた。研究の足しになる訳では無いが、眺めているとおもしろい発見があるのでいろいろなモデルを変換して見ることをおすすめする。

EfficientSegnextの計算グラフ(抜粋)

おまけ: ONNXの推論

変換したONNXはランタイムで推論できる。

def onnxload():
    import onnxruntime as rt
    import onnx

    datamodule = KvasirDataModule(batch_size=1, img_size=(256,256)) 
    datamodule.setup(None)
    img, msk = datamodule.test_dataset[1234]
    img = torch.unsqueeze(img, dim=0)
    img = img.numpy()

    sess = rt.InferenceSession(
        "efficientsegnext.onnx", 
        providers=rt.get_available_providers()
        )
    input_name = sess.get_inputs()[0].name
    output = sess.run(None, {input_name: img.astype(np.float32)})[0]
    
    import matplotlib.pyplot as plt
    plt.imshow(output[0,0], cmap="hot")
    plt.savefig(f"mask_onnxout.png")

    return None

よきセマンティックセグメンテーションライフを!

Discussion