Mosic Augmentationを使ったSemantic Segmentationモデル実装
初学者向きの日本語記事で細かい設計に踏み込んだ解説が少ないので書きます。
タスクに関わる部分は可能な限りtorch
とtorchvision
で完結しシンプルになるように心がけました。この記事で説明することは主に次の内容です。
- セグメンテーションタスクの概観
- データ作成
- データセットクラスの作成
- セグメンテーションモデルの作成
- 損失関数と評価指標の作成
- 学習ループ作成
- 学習と検証
作業環境
私はpyenvなどよりDocker上に実験環境を作る方が好みなので、以下のDockerfileを使ってコンテナで作業しているが、おそらくGoogle Colabなどを用いても同じことはできると思う。
Colabではpip install ...
の部分をipynb上にコピペすればライブラリを用意できる。
ARG PYTORCH="2.1.1"
ARG CUDA="12.1"
ARG CUDNN="8"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" \
TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
FORCE_CUDA="1" \
DEBIAN_FRONTEND=noninteractive
# Install the required packages
RUN apt-get update
RUN apt-get install -y ffmpeg git ninja-build libsm6 libxext6 libglib2.0-0 libsm6 libxrender-dev libxext6 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN pip install lightning matplotlib seaborn onnx onnxruntime onnxoptimizer scikit-learn \
&& pip install timm torchinfo torchmetrics transformers polars
CMD ["/bin/bash"]
セマンティックセグメンテーション
最も基本的なタスクとして画像分類が有名だが、分類モデルは入力画像
今後混乱がないように書くと、PyTorchの画像系NNモデルは一般的に入力次元が[Batch(B
), Channel(C
), Height(H
), Width(W
)]となっている点に注意。Batchは学習時のミニバッチの枚数。
これから実行するモデルは、画像img
と訓練データ(マスク)msk
を受け取り、教師あり学習を行うことで、img
からマスクを推定する。まずはデータセットを作っていく。
データセットの作成
今回は細胞の輪郭から細胞の画素と種類を推定するタスクを実行する。
データセットは以下のLiveCell 2021のCOCO Datasetからマスクを取り出して、画像に整形して作ることにするが、形式さえ同じなら何を使ってもいい。
A172
, BT-474
, BV-2
, Huh7
, MCF7
, SH-SY5Y
, SkBr3
, SK-OV-3
の8クラスの細胞があり、それぞれ1つの画像の中に1クラスが写っている。
[1] Christoffer Edlund, Timothy R. Jackson, Nabeel Khalid, Nicola Bevan, Timothy Dale, Andreas Dengel, Sheraz Ahmed, Johan Trygg & Rickard Sjögren
"LIVECell—A large-scale dataset for label-free live cell segmentation"
https://www.nature.com/articles/s41592-021-01249-6
[1] 細胞画像の特徴空間上での分布
[1] マスクの例
このデータセットを整形したのがこちら。
整形内容
- 全画像を256×256で右上、左上、右下、左下部分をクロップ
- COCOアノテーションファイルのポリゴン頂点をグレースケール画像化
整形済み画像データとマスクデータ
左上からA172
, BT-474
, BV-2
, Huh7
、下段にMCF7
, SH-SY5Y
, SkBr3
, SK-OV-3
となっている。素人には判別がつかない。
マスクも同じクラスの並びにしている。それぞれのクラスのマスクは輝度値が10ずつ異なる。これによりプログラム部分で
自作データセットを作る場合はこのようにカラー画像と対応するグレースケールマスク画像を作ることになる。今回は同一データ内に異なるクラスのマスクが無い(後で工夫して混合する)が、一般的には混ざっていても問題ない。グレースケールでなくても、カラーパレットのRGB値を持っているなら原理的には同様の操作が可能なので、クラス数が255より多い場合などはカラーを使う手もある。
参考までに、他のデータセットとしてCOCO Stuffも併載する。
[2] Holger Caesar, Jasper Uijlings, Vittorio Ferrari,
"COCO-Stuff: Thing and Stuff Classes in Context"
https://arxiv.org/abs/1612.03716
COCO Stuffの画像とマスク
[2] COCO Stuffのクラス情報など
データセットを整形したら、次のようにデータを配置する。今回、trainとvalは両方trainに混ぜてしまい、testのみ隔離する。また、画像と対応するマスク画像は同じ名前にすると便利。
experiment/
└─ dataset/
... ├─ images_train/
├─ train_cls1_0000000_0.jpg
└─ train_cls1_0000000_1.jpg ...
├─ images_test/
├─ test_cls1_0000000_0.jpg
└─ test_cls1_0000000_1.jpg ...
├─ masks_train/
├─ train_cls1_0000000_0.png
└─ train_cls1_0000000_1.png ...
└─ images_test/
├─ test_cls1_0000000_0.png
└─ test_cls1_0000000_1.png ...
これで一旦準備ができたので、次はデータをPythonへ読み込むプログラム部分を作っていく。
データの読み込み
コードを書く前に、この記事で使用するライブラリ一覧を示す。また登場次第適宜紹介する。
-
import glob
: ファイル名を部分一致で検索しリストにして持ってくる -
import multiprocessing
: おまじない -
import warnings
: おまじない -
import matplotlib.pyplot as plt
: テンソルを画像化するとき使う -
import torch
: いつもの -
from torch import nn
: モデル構築に使う -
from torch.nn import functional as F
: テンソルに関数を適用するとき使う -
import torchvision
: 画像読み込みなどでつかう -
from torchvision.transforms import v2 as transforms
: 画像変換 -
from torchvision.transforms.v2 import functional as transformsF
: 画像変換関数 -
import lightning as pl
: pytorchの最適化ループを簡単にしてくれる -
import timm
: 事前学習済み画像分類NNの動物園
torch.utils.data.Dataset
部分
データの対をインデックスアクセスで取ってこれるクラスを作る。
記事ではクラス数などを直書きしているが、実際は引数などに持たせるように書き、クラスなどは他のファイルから呼び出している。
class LivecellMosicDataset(torch.utils.data.Dataset):
def __init__(self, images, masks, is_train=False):
# input:
# images: 画像のpathのリスト
# masks: マスク画像のpathのリスト
# is_train: bool trainデータの場合True
self.images = images
self.masks = masks
# trainデータならモザイク組み込みするインスタンスを作る
# testデータなら普通に読み込むインスタンスを作る
self.load = MosicImageLoader(images, masks) if is_train == True else ImageLoader(images, masks)
# 画像とマスク画像に対して画像処理(data augmentation)を行う
self.augmentation = MyResizeFlipCrop(size=IMGSIZE[0], normrize_img=True)
return None
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
# 画像とマスクを読み込む
img, msk = self.load(idx)
# 画像とマスクに対して画像処理を行う
img, msk = self.augmentation(img, msk)
# マスク画像は輝度値10*NUM_CLASSなので、1刻みに変える
msk = transforms.functional.rgb_to_grayscale(msk)//10
# マスク画像[1,H,W]を要素{0, 1}の8クラスのマスクを表現するテンソル[8,H,W]へ変換する
msk = maskonehot(msk, 8)
return img.to(torch.float), msk.to(torch.float)
この中に出てくるImageLoader
, MosicImageLoader
, MyResizeFlipCrop
, maskonehot
はこれから作っていく。
概観としては、次のように動作して欲しいというものである。
dataset = LivecellMosicDataset([dataset/a.jpg ...], [dataset/a.png ...])
len(dataset) # = 5678
img, msk = dataset[123]
img # tensor [3,256,256]
msk # tensor [8,256,256]
ImageLoader
の作成
クラスImageLoader
は単純に画像とマスクのpathを受け取って画像とマスクを返す。
class ImageLoader():
def __init__(self, pathes_img, pathes_msk):
self.pathes_img = pathes_img
self.pathes_msk = pathes_msk
return None
def __call__(self, idx):
img = torchvision.io.read_image(self.pathes_img[idx]).to(torch.float)
msk = torchvision.io.read_image(self.pathes_msk[idx]).to(torch.float)
return img, msk
MosicImageLoader
部分
イレギュラーなので読み飛ばしても問題ない。
今回例で使うデータセットは1枚の画像に1クラスのマスクしか無いため、特徴の増幅のためMosaic data augmentationを行った。
[3] Alexey Bochkovskiy, Chien-Yao Wang, Hong-Yuan Mark Liao,
"YOLOv4: Optimal Speed and Accuracy of Object Detectio"
https://arxiv.org/abs/2004.10934
[3] YOLOv4のデータ拡張
アクセスした画像のインデックスから4つの画像を決定的に持ってきて、結合したのちランダムにクロップすることで実装する。従って、出力画像はオリジナルの2倍の大きさになる点には注意。
class MosicImageLoader():
def __init__(self, pathes_img, pathes_msk):
# input:
# pathes_img: 画像のpathのリスト
# pathes_msk: マスク画像のpathのリスト
self.pathes_img = pathes_img
self.pathes_msk = pathes_msk
self.l = len(pathes_img)
# アクセスした画像以外のpathを決定する乱数を決定
self.rand = torch.randint(high=len(pathes_img), size=(3,))
return None
def mosicconcat(self, img_list):
# input List[tensor[C,H,W] × 4]
img1 = torch.cat([img_list[0], img_list[1]], dim=2)
img2 = torch.cat([img_list[2], img_list[3]], dim=2)
img = torch.cat([img1, img2], dim=1)
return img
def __call__(self, idx):
img = self.mosicconcat([
torchvision.io.read_image(self.pathes_img[idx]),
torchvision.io.read_image(self.pathes_img[(idx+self.rand[0])%self.l]),
torchvision.io.read_image(self.pathes_img[(idx+self.rand[1])%self.l]),
torchvision.io.read_image(self.pathes_img[(idx+self.rand[2])%self.l]),
]).to(torch.float)
msk = self.mosicconcat([
torchvision.io.read_image(self.pathes_msk[idx]),
torchvision.io.read_image(self.pathes_msk[(idx+self.rand[0])%self.l]),
torchvision.io.read_image(self.pathes_msk[(idx+self.rand[1])%self.l]),
torchvision.io.read_image(self.pathes_msk[(idx+self.rand[2])%self.l]),
]).to(torch.float)
return img, msk
MyResizeFlipCrop
部分
torchvision.transforms
の変形はマスク画像に同じ変形を適用できない。(できるらしいがうまく行かなかった)
そのため、入力された画像とマスクに同じ変形を行うクラスを作成する。
class MyResizeFlipCrop():
def __init__(self, size=256, clopscale=0.5, hflip=0.5, vflip=0.5, normrize_img=True):
# input:
# size: 変換後出力される画像のサイズ
# clopscale: 元画像に対してどの割合でクロップするか
# hflip, vflip: 画像の反転確率
# normrize_img: 画像を標準化するかどうか(基本的にはする)
self.size = size
self.scale = clopscale
self.hfp = hflip
self.vfp = vflip
self.normrize = Transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) if normrize_img == True else nn.Identity()
return None
def __call__(self, img, msk):
# input [3,H,W], [C,H,W]
scale = self.scale
y, x = torch.randint(high=int(self.size/scale-self.size-1), size=(2,))
img = transformsF.resize(
img,
int(self.size/scale),
interpolation=torchvision.transforms.InterpolationMode.BICUBIC)[:,x:x+self.size, y:y+self.size]
img = self.normrize(img)
msk = transformsF.resize(
msk,
int(self.size/scale),
interpolation=torchvision.transforms.InterpolationMode.NEAREST_EXACT)[:,x:x+self.size, y:y+self.size]
if torch.rand(1) < self.hfp:
img = transformsF.horizontal_flip(img)
msk = transformsF.horizontal_flip(msk)
if torch.rand(1) < self.vfp:
img = transformsF.vertical_flip(img)
msk = transformsF.vertical_flip(msk)
return img, msk
maskonehot
部分
マスク画像を、モデル出力と比較するためのテンソルに変換する。
この記事ではグレースケール画像のマスク画像(成分は0~255)を8クラスのone-hotなテンソルにする。
以下の記事は非常に具体的な図が載っていて理解の助けになると思う。
def maskonehot(msk, n_cls: int):
# input [1,H,W] -> output [CLS,H,W]
# f: \mathbb{N}^{1×H×W} -> [0,1]^{CLS×H×W}
return torch.concat([torch.where(msk == c+1, 1, 0) for c in range(n_cls)], dim=0)
mskimg = torch.zeros(1,256,256)
msktensor = maskonehot(mskimg, 8)
msktensor # tensor [8,256,256]
以上によりデータセットの準備は完了した。次は学習モデルについて見ていく。
セグメンテーションモデルの作成
普通は色々なライブラリから事前学習済みモデルを持ってくるのが効率がいいが、この記事では学習のためプログラム本体を書いて行く。
方針としては、timm
からweightごとEfficientNetv2もモデルをビルドし、これをバックボーンとして3種類の解像度の特徴マップを引っこ抜いてきて、LightHamの分類器に入力する形にする。
つまりEfficientSegNextを作る。
Mingxing Tan, Quoc V. Le,
"EfficientNetV2: Smaller Models and Faster Training"
https://arxiv.org/abs/2104.00298
Zhengyang Geng, Meng-Hao Guo, Hongxu Chen, Xia Li, Ke Wei, Zhouchen Lin,
"Is Attention Better Than Matrix Decomposition?"
https://arxiv.org/abs/2109.04553
[4] Meng-Hao Guo, Cheng-Ze Lu, Qibin Hou, Zhengning Liu, Ming-Ming Cheng, Shi-Min Hu
"SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation"
https://arxiv.org/abs/2209.08575
SegNextの解説は以前記事にしたので詳しくは以下に譲る
以下の図の(a)と(c)を融合した形をつくることで、EfficientNetv2の速度性能と行列分解デコーダの速度性能が両方備わり最強に見える。
[4] エンコーダデコーダの比較
モデルの詳細な計算処理などは深く触れないが、モジュール単位で何がやり取りされているかにのみ注目する。
EfficientNetv2 Backbone
class EfficientNetBackbone(nn.Module):
def __init__(self):
super().__init__()
# ここでEfficientNetv2をビルド
model = timm.create_model(
"tf_efficientnetv2_s.in21k_ft_in1k",
pretrained=True,
)
# 以下でモデルを分解
self.stem = nn.Sequential(
model.conv_stem,
model.bn1
)
self.block0 = model.blocks[0]
self.block1 = model.blocks[1]
self.block2 = model.blocks[2]
self.block3 = model.blocks[3]
self.block4 = model.blocks[4]
# 256×256入力ではblock5に殆ど位置情報が残らないため、断腸の思いで捨てる
# self.block5 = model.blocks[5]
# channel数をデコーダと合わせるためのMLP層
self.mlp1 = nn.Conv2d( 48, 64, 1)
self.mlp2 = nn.Conv2d( 64, 128, 1)
self.mlp3 = nn.Conv2d(160, 256, 1)
return None
def forward(self, x):
# input:
# x: 画像入力 [B,3,H,W]
x = self.stem(x)
x = self.block0(x) # output [B, 24, H/2, W/2]
x = self.block1(x) # output [B, 48, H/4, W/4]
x1 = self.mlp1(x)
x = self.block2(x) # output [B, 64, H/8, W/8]
x2 = self.mlp2(x)
x = self.block3(x) # output [B,128,H/16,W/16]
x = self.block4(x) # output [B,160,H/16,W/16]
x3 = self.mlp3(x)
# x = self.block5(x) # (output [B,256,H/32,W/32])
return x1, x2, x3 # HW/4, HW/8, HW/16
これにより解像度がオリジナルの1/4, 1/8, 1/16の特徴マップを出力できるので、これをデコーダへ入力する。
LightHam Decoder
入力された3つの特徴マップの大きさを揃えて結合し、これをMLPと行列分解で特徴を分離しクラス数のテンソルに変換する。
最後に、空間方向を4倍してオリジナルのマスクテンソルの大きさと揃えて、Sigmoid関数で[0,1]に変換すれば出力形が完成する。
行列分解部分+MLP部分
class NMF2D(nn.Module):
def __init__(self, device="cuda"):
super().__init__()
self.spatial = True
self.S = 1
self.D = 512
self.R = 64
self.train_steps = 6
self.eval_steps = 7
self.inv_t = 1 # default 100
self.eta = 0.9
self.rand_init = True
# Use for device verification such as `self.tensordevice.device == torch.device('cpu')`
self.tensordevice = torch.device(device) # nn.Parameter(torch.empty(0))
return None
def _build_bases(self, B, S, D, R):
# _MatrixDecomposition2DBase default empty
if self.tensordevice == torch.device('cuda'):
bases = torch.rand((B * S, D, R)).cuda()
else:
print("cpu mode")
bases = torch.rand((B * S, D, R))
bases = F.normalize(bases, dim=1)
return bases
def local_step(self, x, bases, coef):
if self.tensordevice == torch.device('cuda'):
bases = bases.cuda()
else:
print("cpu mode")
# _MatrixDecomposition2DBase default empty
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
numerator = torch.bmm(x.transpose(1, 2), bases)
# (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
# Multiplicative Update
coef = coef * numerator / (denominator + 1e-6)
# (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
numerator = torch.bmm(x, coef)
# (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
# Multiplicative Update
bases = bases * numerator / (denominator + 1e-6)
return bases, coef
def local_inference(self, x, bases):
if self.tensordevice == torch.device('cuda'):
bases = bases.cuda()
else:
print("cpu mode")
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
coef = torch.bmm(x.transpose(1, 2), bases)
coef = F.softmax(self.inv_t * coef, dim=-1)
steps = self.train_steps if self.training else self.eval_steps
for _ in range(steps):
bases, coef = self.local_step(x, bases, coef)
return bases, coef
def compute_coef(self, x, bases, coef):
if self.tensordevice == torch.device('cuda'):
bases = bases.cuda()
else:
print("cpu mode")
# _MatrixDecomposition2DBase default empty
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
numerator = torch.bmm(x.transpose(1, 2), bases)
# (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
# multiplication update
coef = coef * numerator / (denominator + 1e-6)
return coef
def forward(self, x, return_bases=False):
B, C, H, W = x.shape
# (B, C, H, W) -> (B * S, D, N)
if self.spatial:
D = C // self.S
N = H * W
x = x.view(B * self.S, D, N)
else:
D = H * W
N = C // self.S
x = x.view(B * self.S, N, D).transpose(1, 2)
if not self.rand_init and not hasattr(self, 'bases'):
bases = self._build_bases(1, self.S, D, self.R)
self.register_buffer('bases', bases)
# (S, D, R) -> (B * S, D, R)
if self.rand_init:
bases = self._build_bases(B, self.S, D, self.R)
else:
bases = self.bases.repeat(B, 1, 1)
bases, coef = self.local_inference(x, bases)
# (B * S, N, R)
coef = self.compute_coef(x, bases, coef)
# (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
x = torch.bmm(bases, coef.transpose(1, 2))
# (B * S, D, N) -> (B, C, H, W)
if self.spatial:
x = x.view(B, C, H, W)
else:
x = x.transpose(1, 2).view(B, C, H, W)
# (B * H, D, R) -> (B, H, N, D)
bases = bases.view(B, self.S, D, self.R)
return x
class Hamburger(nn.Module):
def __init__(self, ham_channels, device="cuda"):
super().__init__()
self.hamburger = nn.Sequential(
nn.Conv2d(ham_channels, ham_channels, 1),
nn.ReLU(inplace=True),
NMF2D(device=device),
nn.Conv2d(ham_channels, ham_channels, 1)
)
self.relu = nn.ReLU(inplace=True)
return None
def forward(self, x):
ham = self.hamburger(x)
ham = self.relu(x + ham)
return ham
class LightHamHead(nn.Module):
def __init__(self, in_channels=[128, 320, 512], ham_channels=512, channels=512, num_classes=150, device="cuda"):
super().__init__()
input_chs = sum(in_channels)
output_chs = channels
self.upsample_x2 = nn.Upsample(scale_factor=2, mode='bicubic')
self.upsample_x4 = nn.Upsample(scale_factor=4, mode='bicubic')
self.squeeze = nn.Conv2d(input_chs, ham_channels, 1)
self.hamburger = Hamburger(ham_channels, device=device)
self.align = nn.Conv2d(ham_channels, output_chs, 1)
self.cls_seg = nn.Conv2d(output_chs, num_classes, 1)
return None
def forward(self, p1, p2, p3):
inputs = torch.cat([p1, self.upsample_x2(p2), self.upsample_x4(p3)], dim=1)
x = self.squeeze(inputs)
x = self.hamburger(x)
output = self.align(x)
output = self.cls_seg(output)
return output
class EfficientSegnext(nn.Module):
def __init__(self, num_class=8, device="cuda"):
super().__init__()
# バックボーンとデコーダを結合する
self.backbone = EfficientNetBackbone()
self.head = LightHamHead(
in_channels=[64,128,256],
ham_channels=256,
channels=256,
num_classes=num_class,
device=device
)
# テンソルのサイズをオリジナルの大きさに戻す
self.upsample_x4 = nn.Upsample(scale_factor=4, mode='bicubic')
# 値を[0,1]にする
self.sigmoid = nn.Sigmoid()
return None
def forward(self, x):
p1, p2, p3 = self.backbone(x)
out = self.head(p1, p2, p3)
out = self.upsample_x4(out)
out = self.sigmoid(out)
return out
これでモデルは完成である。バックボーンのEfficientNetv2は学習済みの重みを持っているので、ある程度セグメンテーションの学習時間を短縮できる。
このモデルは次のように使える。
model = EfficientSegnext(device="cpu")
model # モデルの全体像は下の折りたたみ欄に記載
out = model(torch.zeros(1,3,256,256))
out.shape # tensor [1,8,256,256]
モデルの全体像
EfficientSegnext(
(backbone): EfficientNetBackbone(
(stem): Sequential(
(0): Conv2dSame(3, 24, kernel_size=(3, 3), stride=(2, 2), bias=False)
(1): BatchNormAct2d(
24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
)
(block0): Sequential(
(0): ConvBnAct(
(conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNormAct2d(
24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(drop_path): Identity()
)
(1): ConvBnAct(
(conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNormAct2d(
24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(drop_path): Identity()
)
)
(block1): Sequential(
(0): EdgeResidual(
(conv_exp): Conv2dSame(24, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)
(bn1): BatchNormAct2d(
96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): Identity()
(conv_pwl): Conv2d(96, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNormAct2d(
48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(1): EdgeResidual(
(conv_exp): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNormAct2d(
192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): Identity()
(conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNormAct2d(
48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(2): EdgeResidual(
(conv_exp): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNormAct2d(
192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): Identity()
(conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNormAct2d(
48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(3): EdgeResidual(
(conv_exp): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNormAct2d(
192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): Identity()
(conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNormAct2d(
48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
)
(block2): Sequential(
(0): EdgeResidual(
(conv_exp): Conv2dSame(48, 192, kernel_size=(3, 3), stride=(2, 2), bias=False)
(bn1): BatchNormAct2d(
192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): Identity()
(conv_pwl): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNormAct2d(
64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(1): EdgeResidual(
(conv_exp): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNormAct2d(
256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): Identity()
(conv_pwl): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNormAct2d(
64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(2): EdgeResidual(
(conv_exp): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNormAct2d(
256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): Identity()
(conv_pwl): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNormAct2d(
64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(3): EdgeResidual(
(conv_exp): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNormAct2d(
256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): Identity()
(conv_pwl): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNormAct2d(
64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
)
(block3): Sequential(
(0): InvertedResidual(
(conv_pw): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2dSame(256, 256, kernel_size=(3, 3), stride=(2, 2), groups=256, bias=False)
(bn2): BatchNormAct2d(
256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(256, 16, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(1): InvertedResidual(
(conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)
(bn2): BatchNormAct2d(
512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(2): InvertedResidual(
(conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)
(bn2): BatchNormAct2d(
512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(3): InvertedResidual(
(conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)
(bn2): BatchNormAct2d(
512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(4): InvertedResidual(
(conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)
(bn2): BatchNormAct2d(
512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(5): InvertedResidual(
(conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)
(bn2): BatchNormAct2d(
512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
)
(block4): Sequential(
(0): InvertedResidual(
(conv_pw): Conv2d(128, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
768, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)
(bn2): BatchNormAct2d(
768, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(768, 32, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(32, 768, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(768, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(1): InvertedResidual(
(conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
(bn2): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(2): InvertedResidual(
(conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
(bn2): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(3): InvertedResidual(
(conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
(bn2): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(4): InvertedResidual(
(conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
(bn2): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(5): InvertedResidual(
(conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
(bn2): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(6): InvertedResidual(
(conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
(bn2): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(7): InvertedResidual(
(conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
(bn2): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(8): InvertedResidual(
(conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
(bn2): BatchNormAct2d(
960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(se): SqueezeExcite(
(conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))
(act1): SiLU(inplace=True)
(conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNormAct2d(
160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
)
(mlp1): Conv2d(48, 64, kernel_size=(1, 1), stride=(1, 1))
(mlp2): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
(mlp3): Conv2d(160, 256, kernel_size=(1, 1), stride=(1, 1))
)
(head): LightHamHead(
(upsample_x2): Upsample(scale_factor=2.0, mode='bicubic')
(upsample_x4): Upsample(scale_factor=4.0, mode='bicubic')
(squeeze): Conv2d(448, 256, kernel_size=(1, 1), stride=(1, 1))
(hamburger): Hamburger(
(hamburger): Sequential(
(0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(1): ReLU(inplace=True)
(2): NMF2D()
(3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
)
(relu): ReLU(inplace=True)
)
(align): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(cls_seg): Conv2d(256, 8, kernel_size=(1, 1), stride=(1, 1))
)
(upsample_x4): Upsample(scale_factor=4.0, mode='bicubic')
(sigmoid): Sigmoid()
)
次はモデルの出力とマスクの誤差を計算し、最適化で必要な勾配を発生させる損失関数と評価指標を作っていく。
損失関数と評価指標
画像分類などではクロスエントロピーが有名だが、セマンティックセグメンテーションでは領域の重なりを表現するものをよく使う。
今回はIoULossとFocalLossをあわせたFocalIoULossを作っていく。
FocalIoULoss
IoUは(領域の積)/(領域の和)で計算される、推論結果と正解との重なりの割合である。つまり、完全に重なっていれば1になり、全く重なっていなければ0になる。
損失関数は誤差が小さいほど小さくならなくてはならなため、IoULossは
FocalLossはクロスエントロピーをよりコントラストした関数で、背景が大きくフィットしすぎないように関数の勾配を平坦にするγをかけることにより、小さい物体のフィットに有利な効果を持つ。
[5] A visual equation for Intersection over Union (Jaccard Index).
http://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/
Adrian Rosebrock
[6] Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár,
"Focal Loss for Dense Object Detection"
https://arxiv.org/abs/1708.02002
[5] IoUの視覚的な定義と [6] Focal Lossのカーブ
こちらの解説はFocalLossについて詳しく記されている。
実装は以下のFocalLossとBCEDiceLossを参考にした。
# 評価指標
class IoUMetrics(nn.Module):
def __init__(self):
super().__init__()
return None
def forward(self, inputs, targets):
smooth = 1 # 零除算回避用
# テンソルを1列のベクトルに平坦化して計算すると楽
inputs = inputs.view(-1)
targets = targets.view(-1)
# IoUの計算
intersection = (inputs * targets).sum()
union = (inputs + targets).sum() - intersection
return (intersection + smooth)/(union + smooth)
# 損失関数
class FocalIoULoss(nn.Module):
def __init__(self, alpha=0.8, gamma=2):
super().__init__()
self.alpha = alpha
self.gamma = gamma
return None
def forward(self, inputs, targets):
smooth = 1
inputs = inputs.view(-1)
targets = targets.view(-1)
# IoUの計算
intersection = (inputs * targets).sum()
union = (inputs + targets).sum() - intersection
iou = (intersection + smooth)/(union + smooth)
# FocalLossの計算
bec = F.binary_cross_entropy(inputs, targets, reduction='mean')
becexp = torch.exp(-bec)
focal_loss = self.alpha * (1-becexp)**self.gamma * bec
return 1 - iou + focal_loss
以上で損失と評価指標である。次はこれらを使って最適化を行う。
モデルの最適化
最適化を楽に書くツールとしてpytorch lightning
を使う。
まずモデルへデータをミニバッチ化して渡すDataModule
を作り、次にモデルの推論と損失の計算を行うLightningModule
を作る。
最後にこれらをインスタンス化して最適化を開始するTrainer
へ渡せば、自動的に学習と指標の記録が開始される。
DataModule
部分
DataModule
はデータセットインスタンスを格納し、trainデータとtestデータを管理する。
class LivecellDataModule(pl.LightningDataModule):
def __init__(self, batch_size=32, num_workers=4):
super().__init__()
# input:
# batch_size: 学習時にモデルへ一括で渡すミニバッチの枚数
# num_workers: おまじない
# 各データの場所をglobで表現
# "*"の部分は任意の文字列が対応する
self.images_train = "dataset/images_train/*.jpg"
self.images_test = "dataset/images_test/*.jpg"
self.masks_train = "dataset/masks_train/*.png"
self.masks_test = "dataset/masks_test/*.png"
self.batch_size = batch_size
self.num_workers = num_workers
return None
def prepare_data(self):
return None
def setup(self, stage):
# globでデータセットのpathのリストを一括で持ってくる
# その後ソートして画像とマスクの順番が対応することを保証する
images_train = glob.glob(self.images_train)
images_train.sort()
masks_train = glob.glob(self.masks_train)
masks_train.sort()
images_test = glob.glob(self.images_train)
images_test.sort()
masks_test = glob.glob(self.masks_train)
masks_test.sort()
# trainとtestのデータセットインスタンスを作成
self.train_dataset = LivecellMosicDataset(images_train, masks_train, is_train=True)
self.test_dataset = LivecellMosicDataset(images_test, masks_test, is_train=True)
return None
# 以下でtrainとtestのデータセットをデータローダへ渡す
# データローダはデータセットから画像とマスクを呼び出し、ミニバッチ化して返してくれる
# データセット側ではデータは[3,256,256][8,256,256]のサイズで返されるが
# データローダはbatchsizeで纏めて[32,3,256,256][32,8,256,256]で返してくれる
def train_dataloader(self):
train_dataloader = torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers
)
return train_dataloader
def val_dataloader(self):
test_dataloader = torch.utils.data.DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
)
return test_dataloader
少し動かしてみると挙動を理解しやすい。
datamodule = LivecellDataModule(batch_size=10)
datamodule.setup(None)
dataloader = datamodule.train_dataloader()
for imgs, msks in dataloader:
break
imgs # tensor [10,3,256,256]
msks # tensor [10,8,256,256]
モザイク拡張を受けた画像は次のようになっている。
LightningModule
部分
LightningModule
はモデルと最適化設定を管理する。
class SegmentationModel(pl.LightningModule):
def __init__(self):
super().__init__()
# モデル、損失関数、評価指標をインスタンス化
self.model = EfficientSegnext(device="cuda")
self.loss = FocalIoULoss()
self.metrics = IoUMetrics()
return None
def forward(self, x):
# LightningModule本体を呼んだときの推論ロジック
y = self.model(x)
return y
def training_step(self, batch, batch_idx):
# 学習時のロジック
img, msk = batch # ミニバッチを取り出す
out = self.forward(img) # 推論
loss = self.loss(out, msk) # 損失計算
self.log("train_loss", loss, logger=True, prog_bar=True) # 経過を記録
return {"loss": loss} # 損失を返すことでtrainerが勝手に逆伝播してくれる
def validation_step(self, batch, batch_idx):
# 検証時のロジック
img, msk = batch
out = self.forward(img)
m = self.metrics(out, msk)
self.log("IoU", m, logger=True, on_epoch=True, prog_bar=True)
return {"metrics": m}
def configure_optimizers(self):
# 最適化設定
maxepoch = 16
# 正則化がついていて収束が高速安定なAdamWを使う
optimizer = torch.optim.AdamW(
self.parameters(), # モデルのパラメータがLightningModuleに格納されるのでこれを使う
lr=0.001, # 初期学習率
weight_decay=0.01 # 正則化の大きさ
)
# 学習率の減衰はコサインカーブで下げる
# これにより学習率やepoch数で花瓶に過学習が起きることが防げる
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, # 上で定義した最適化手法インスタンス
maxepoch, # 最大epoch数
eta_min=0.00001 # 減衰の最小値
)
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
実行部分
def main():
datamodule = LivecellDataModule(batch_size=BATCHSIZE)
model = SegmentationModel.load_from_checkpoint("model_efficientsegnext_16epoch.ckpt")
pllogger_csv = pl.pytorch.loggers.CSVLogger(
"./logs/",
name="EfficientSegNext" # ログを保存するディレクトリの名前
)
trainer = pl.Trainer(
logger=pllogger_csv, # CSVファイルで最適化の記録を保存
enable_checkpointing=True, # 学習した重みを保存する
check_val_every_n_epoch=1, # 検証を行うインターバル
accelerator="gpu",
devices=1,
max_epochs=16,
)
# 学習ループ実行
trainer.fit(
model,
datamodule=datamodule
)
# 結果の検証 -> list[dict[str][float]]
result = trainer.validate(
model,
datamodule=datamodule
)
print(result)
return None
if __name__ == "__main__":
multiprocessing.freeze_support()
main()
これによりモデルの最適化が周る。実際の結果を見てみると次のようになる。
root@d09aeb1b8aef:/workspace# python train_mosic.py
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
---------------------------------------------
0 | model | EfficientSegnext | 5.7 M
1 | loss | FocalIoULoss | 0
2 | metrics | IoUMetrics | 0
---------------------------------------------
5.7 M Trainable params
0 Non-trainable params
5.7 M Total params
22.610 Total estimated model params size (MB)
Epoch 3: 100%|█████████████████████████████████████████| 470/470 [07:37<00:00, 1.03it/s, v_num=9, train_loss=0.277, IoU=0.719]`Trainer.fit` stopped: `max_epochs=4` reached.
Epoch 3: 100%|█████████████████████████████████████████| 470/470 [07:37<00:00, 1.03it/s, v_num=9, train_loss=0.277, IoU=0.719]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Validation DataLoader 0: 100%|███████████████████████████████████████████████████████████████| 470/470 [02:35<00:00, 3.02it/s]
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Validate metric DataLoader 0
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
IoU 0.7335798740386963
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
IoU = 0.719はちょっと低いが、16epochで軽量なモデルを使っているからだろうか...
実際に結果を可視化してみる。
def mask_vis(img, msk, infer, n_cls, fname=""):
plt.gca().axis("off")
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
plt.imshow(torch.sum(img[0], dim=0), cmap="pink")
plt.savefig(f"visualize_{fname}_img.png")
for c in range(n_cls):
plt.imshow(msk[0,c], cmap="hot")
plt.savefig(f"visualize_{fname}_msk_{c+1}.png")
plt.imshow(infer[0,c], cmap="hot")
plt.savefig(f"visualize_{fname}_infer_{c+1}.png")
plt.clf()
plt.close()
return None
def vis():
# モデルの読み込み
model = SegmentationModel.load_from_checkpoint("model_efficientsegnext_mosicfintuning_16epoch.ckpt")
model = model.model.cuda()
model.eval()
# データモジュールの準備
datamodule = LivecellDataModule(batch_size=1)
datamodule.setup(None)
dataloader = datamodule.val_dataloader()
for img, msk in dataloader:
break
infer = model(img.cuda())
infer = infer.detach().cpu()
# 画像、マスク、推論結果の可視化
mask_vis(img, msk, infer, 8)
return None
if __name__ == "__main__":
multiprocessing.freeze_support()
vis()
上2行が正解マスク、下2行が推論結果
セグメンテーションとしては少しレベルが低いが、クラス分類的にはきれいに行っていると見ていい気がする。
モデルは非常に高速で、ResNet50-DeepLabv3pulsの1/2ほどの時間で学習できた。今回はEfficientNetv2-Sを一部削除して使っているため超高速だが、やはり特徴抽出の部分で厳しかったのかもしれない。余裕があったらEfficientNet以外にもConvNextやHRViTなどのバックボーンを使って試してみるのもいいかもしれない。
ちなみに、モザイク拡張を使わずに学習した場合はIoU = 0.836でそこそこ正確だったが、モザイク拡張を行ったテストデータに対しての精度は全くだめで、分類すらできていなかった。
これはおそらくデコード時のMLP層で大域的に特徴が有るか無いかに影響を受けて識別するようになってしまった説が有力だと思う。やはり画像に多くの特徴を持たせなければ外挿時の精度が落ちるという結論で正しい気がする。
おまけ: モデルのONNX化
モデルを外部のデバイスで使うときなどはONNXでやり取りするのが良い。
def saveonnx():
model = SegmentationModel.load_from_checkpoint("model_effficientsegnext.ckpt")
model = model.model.cpu().eval()
model(torch.ones(1,3,256,256))
dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
torch.onnx.export(
model,
dummy_input,
"efficientsegnext.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=16, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input_image'], # the model's input names
output_names=['output_mask'], # the model's output names
dynamic_axes={
'input_image' : {0 : 'batch_size'},
'output_mask' : {0 : 'batch_size'},
})
return None
これでモデルがONNXの計算グラフに変換され、構造なども見れるようになる。
今回作成したEfficientSegnextは行列分解に漸近アルゴリズムを使っているため、図に載せなかった部分は非常に細長いモデル構造になっていた。研究の足しになる訳では無いが、眺めているとおもしろい発見があるのでいろいろなモデルを変換して見ることをおすすめする。
EfficientSegnextの計算グラフ(抜粋)
おまけ: ONNXの推論
変換したONNXはランタイムで推論できる。
def onnxload():
import onnxruntime as rt
import onnx
datamodule = KvasirDataModule(batch_size=1, img_size=(256,256))
datamodule.setup(None)
img, msk = datamodule.test_dataset[1234]
img = torch.unsqueeze(img, dim=0)
img = img.numpy()
sess = rt.InferenceSession(
"efficientsegnext.onnx",
providers=rt.get_available_providers()
)
input_name = sess.get_inputs()[0].name
output = sess.run(None, {input_name: img.astype(np.float32)})[0]
import matplotlib.pyplot as plt
plt.imshow(output[0,0], cmap="hot")
plt.savefig(f"mask_onnxout.png")
return None
よきセマンティックセグメンテーションライフを!
Discussion