How to customize the backbone of Segmentation Models Pytorch
How to swap in the backbone you want in Segmentation Models Pytorch, and what to watch out for
Introduction
Segmentation Models Pytorch is a library that makes it easy to build models for segmentation tasks on top of PyTorch, the automatic-differentiation framework for deep learning.
It is convenient enough to be useful for people building their first models, in competitions such as Kaggle, and in day-to-day work.
It already offers plenty of encoders to choose from as the backbone.
However, when I wanted to use the latest models or add 3D customization, it was not obvious how to do so, so I am writing down a few examples as a memo for myself.
Implementation
The Python implementation follows.
Recap of the basics
import segmentation_models_pytorch as smp
model = smp.Unet(
encoder_name="resnet34", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization
in_channels=1, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
classes=3, # model output channels (number of classes in your dataset)
)
It can be used as shown above, exactly as in the documentation.
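As a quick sanity check of the forward pass (the input size below is just an example; it should be divisible by 32):
import torch

inputs = torch.ones((1, 1, 256, 256))  # B, C, H, W (matches in_channels=1 above)
outputs = model(inputs)
outputs.shape  # should be torch.Size([1, 3, 256, 256]) since classes=3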
Built-in encoders
To switch to one of the encoders that is already provided, all you need to do is change encoder_name following the documentation below.
To use Mix Vision Transformer (MiT), for example:
import segmentation_models_pytorch as smp
model = smp.Unet(
encoder_name="mit_b4",
encoder_weights="imagenet",
classes=3,
)
Timm encoders
On top of that, you can also use encoders from Pytorch Image Models, commonly known as timm, a library that provides PyTorch models.
With this, you can probably cover most of the models you will ever need.
You can pick a model from the documentation below; there are more than 500 of them.
Say, for example, you want to use MaxViT as a timm encoder.
A timm model is loaded by setting encoder_name to tu- followed by the timm model name.
For maxvit_small_tf_224, you specify tu-maxvit_small_tf_224.
model = smp.Unet(
encoder_name="tu-maxvit_small_tf_224",
encoder_weights="imagenet",
classes=3,
)
Custom models
Now for the main topic: how to customize the backbone, with examples.
Basic approach
To add an encoder to the Segmentation Models PyTorch library, you only need to implement the following two things:
1. Inherit from the EncoderMixin class
2. Register the class under a name in smp.encoders.encoders
As an example, let's add maxvit.
A small tip: timm's list_models function lets you search the implemented model names, including partial (wildcard) matches.
import timm
timm.list_models('maxvit*')
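If you only want the variants that ship with pretrained weights, list_models also takes a pretrained flag:
timm.list_models('maxvit*', pretrained=True)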
First, as a simple case, we extend the library with timm's maxvit_small_tf_224 model.
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
import timm
from segmentation_models_pytorch.encoders._base import EncoderMixin

class MiddleSkipConnectionEncoder(nn.Module, EncoderMixin):
    def __init__(self, **kwargs):
        super().__init__()
        # number of channels of each feature tensor returned by forward()
        self._out_channels = [1, 64, 96, 192, 384, 768]
        self._depth: int = 5
        self._in_channels: int = 1
        self.backbone = timm.create_model('maxvit_small_tf_224',
                                          pretrained=False,
                                          features_only=True,
                                          in_chans=self._in_channels)

    def forward(self, x: torch.Tensor):
        features = self.backbone(x)
        # keep the 1-channel input itself as the highest-resolution skip feature
        return [x, *features]
That takes care of step 1, inheriting the class.
smp.encoders.encoders["maxvit224_encoder"] = {
    "encoder": MiddleSkipConnectionEncoder,
    "params": {},
    "pretrained_settings": {},
}
That takes care of step 2, registering the class under a name.
Now let's check that it works.
model = smp.Unet(encoder_name='maxvit224_encoder', encoder_weights=None, classes=1)
# forward
inputs = torch.ones((1, 1, 224, 224)) # B, C, H, W
outputs = model(inputs)
# outputs
outputs.shape
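If the encoder is registered correctly, outputs.shape should come back as torch.Size([1, 1, 224, 224]): the deepest MaxViT feature sits at 1/32 resolution and the depth-5 decoder upsamples five times, so we end up back at the input size.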
For completeness, here is the tu-maxvit_small_tf_224 version as well.
model = smp.Unet(
encoder_name="tu-maxvit_small_tf_224",
encoder_weights="imagenet",
classes=3,
)
Your own model
Now, what should we do when we want to bring in NextViT?
We use the official code below as a reference.
Encoder implementation
It is mostly a straight copy, although parts of it have been rearranged so that everything fits in a single file.
# Copyright (c) ByteDance Inc. All rights reserved.
from functools import partial
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from torch.nn.modules.batchnorm import _BatchNorm
from timm.models.layers import DropPath, trunc_normal_
NORM_EPS = 1e-5
def merge_pre_bn(module, pre_bn_1, pre_bn_2=None):
""" Merge pre BN to reduce inference runtime.
"""
weight = module.weight.data
if module.bias is None:
zeros = torch.zeros(module.out_channels, device=weight.device).type(weight.type())
module.bias = nn.Parameter(zeros)
bias = module.bias.data
if pre_bn_2 is None:
assert pre_bn_1.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
assert pre_bn_1.affine is True, "Unsupport bn_module.affine is False"
scale_invstd = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
extra_weight = scale_invstd * pre_bn_1.weight
extra_bias = pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd
else:
assert pre_bn_1.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
assert pre_bn_1.affine is True, "Unsupport bn_module.affine is False"
assert pre_bn_2.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
assert pre_bn_2.affine is True, "Unsupport bn_module.affine is False"
scale_invstd_1 = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
scale_invstd_2 = pre_bn_2.running_var.add(pre_bn_2.eps).pow(-0.5)
extra_weight = scale_invstd_1 * pre_bn_1.weight * scale_invstd_2 * pre_bn_2.weight
extra_bias = scale_invstd_2 * pre_bn_2.weight *(pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd_1 - pre_bn_2.running_mean) + pre_bn_2.bias
if isinstance(module, nn.Linear):
extra_bias = weight @ extra_bias
weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
elif isinstance(module, nn.Conv2d):
assert weight.shape[2] == 1 and weight.shape[3] == 1
weight = weight.reshape(weight.shape[0], weight.shape[1])
extra_bias = weight @ extra_bias
weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
weight = weight.reshape(weight.shape[0], weight.shape[1], 1, 1)
bias.add_(extra_bias)
module.weight.data = weight
module.bias.data = bias
class ConvBNReLU(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
groups=1):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=1, groups=groups, bias=False)
self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS)
self.act = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
x = self.act(x)
return x
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class PatchEmbed(nn.Module):
def __init__(self,
in_channels,
out_channels,
stride=1):
super(PatchEmbed, self).__init__()
norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
if stride == 2:
self.avgpool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False)
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
self.norm = norm_layer(out_channels)
elif in_channels != out_channels:
self.avgpool = nn.Identity()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
self.norm = norm_layer(out_channels)
else:
self.avgpool = nn.Identity()
self.conv = nn.Identity()
self.norm = nn.Identity()
def forward(self, x):
return self.norm(self.conv(self.avgpool(x)))
class MHCA(nn.Module):
"""
Multi-Head Convolutional Attention
"""
def __init__(self, out_channels, head_dim):
super(MHCA, self).__init__()
norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
self.group_conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
padding=1, groups=out_channels // head_dim, bias=False)
self.norm = norm_layer(out_channels)
self.act = nn.ReLU(inplace=True)
self.projection = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False)
def forward(self, x):
out = self.group_conv3x3(x)
out = self.norm(out)
out = self.act(out)
out = self.projection(out)
return out
class Mlp(nn.Module):
def __init__(self, in_features, out_features=None, mlp_ratio=None, drop=0., bias=True):
super().__init__()
out_features = out_features or in_features
hidden_dim = _make_divisible(in_features * mlp_ratio, 32)
self.conv1 = nn.Conv2d(in_features, hidden_dim, kernel_size=1, bias=bias)
self.act = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(hidden_dim, out_features, kernel_size=1, bias=bias)
self.drop = nn.Dropout(drop)
def merge_bn(self, pre_norm):
merge_pre_bn(self.conv1, pre_norm)
def forward(self, x):
x = self.conv1(x)
x = self.act(x)
x = self.drop(x)
x = self.conv2(x)
x = self.drop(x)
return x
class NCB(nn.Module):
"""
Next Convolution Block
"""
def __init__(self, in_channels, out_channels, stride=1, path_dropout=0,
drop=0, head_dim=32, mlp_ratio=3):
super(NCB, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
assert out_channels % head_dim == 0
self.patch_embed = PatchEmbed(in_channels, out_channels, stride)
self.mhca = MHCA(out_channels, head_dim)
self.attention_path_dropout = DropPath(path_dropout)
self.norm = norm_layer(out_channels)
self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop, bias=True)
self.mlp_path_dropout = DropPath(path_dropout)
self.is_bn_merged = False
def merge_bn(self):
if not self.is_bn_merged:
self.mlp.merge_bn(self.norm)
self.is_bn_merged = True
def forward(self, x):
x = self.patch_embed(x)
x = x + self.attention_path_dropout(self.mhca(x))
if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
out = self.norm(x)
else:
out = x
x = x + self.mlp_path_dropout(self.mlp(out))
return x
class E_MHSA(nn.Module):
"""
Efficient Multi-Head Self Attention
"""
def __init__(self, dim, out_dim=None, head_dim=32, qkv_bias=True, qk_scale=None,
attn_drop=0, proj_drop=0., sr_ratio=1):
super().__init__()
self.dim = dim
self.out_dim = out_dim if out_dim is not None else dim
self.num_heads = self.dim // head_dim
self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
self.v = nn.Linear(dim, self.dim, bias=qkv_bias)
self.proj = nn.Linear(self.dim, self.out_dim)
self.attn_drop = nn.Dropout(attn_drop)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
self.N_ratio = sr_ratio ** 2
if sr_ratio > 1:
self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
self.is_bn_merged = False
def merge_bn(self, pre_bn):
merge_pre_bn(self.q, pre_bn)
if self.sr_ratio > 1:
merge_pre_bn(self.k, pre_bn, self.norm)
merge_pre_bn(self.v, pre_bn, self.norm)
else:
merge_pre_bn(self.k, pre_bn)
merge_pre_bn(self.v, pre_bn)
self.is_bn_merged = True
def forward(self, x):
B, N, C = x.shape
q = self.q(x)
q = q.reshape(B, N, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x_ = x.transpose(1, 2)
x_ = self.sr(x_)
if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
x_ = self.norm(x_)
x_ = x_.transpose(1, 2)
k = self.k(x_)
k = k.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
v = self.v(x_)
v = v.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
else:
k = self.k(x)
k = k.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
v = self.v(x)
v = v.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
attn = (q @ k) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class NTB(nn.Module):
"""
Next Transformer Block
"""
def __init__(
self, in_channels, out_channels, path_dropout, stride=1, sr_ratio=1,
mlp_ratio=2, head_dim=32, mix_block_ratio=0.75, attn_drop=0, drop=0,
):
super(NTB, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.mix_block_ratio = mix_block_ratio
norm_func = partial(nn.BatchNorm2d, eps=NORM_EPS)
self.mhsa_out_channels = _make_divisible(int(out_channels * mix_block_ratio), 32)
self.mhca_out_channels = out_channels - self.mhsa_out_channels
self.patch_embed = PatchEmbed(in_channels, self.mhsa_out_channels, stride)
self.norm1 = norm_func(self.mhsa_out_channels)
self.e_mhsa = E_MHSA(self.mhsa_out_channels, head_dim=head_dim, sr_ratio=sr_ratio,
attn_drop=attn_drop, proj_drop=drop)
self.mhsa_path_dropout = DropPath(path_dropout * mix_block_ratio)
self.projection = PatchEmbed(self.mhsa_out_channels, self.mhca_out_channels, stride=1)
self.mhca = MHCA(self.mhca_out_channels, head_dim=head_dim)
self.mhca_path_dropout = DropPath(path_dropout * (1 - mix_block_ratio))
self.norm2 = norm_func(out_channels)
self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop)
self.mlp_path_dropout = DropPath(path_dropout)
self.is_bn_merged = False
def merge_bn(self):
if not self.is_bn_merged:
self.e_mhsa.merge_bn(self.norm1)
self.mlp.merge_bn(self.norm2)
self.is_bn_merged = True
def forward(self, x):
x = self.patch_embed(x)
B, C, H, W = x.shape
if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
out = self.norm1(x)
else:
out = x
out = rearrange(out, "b c h w -> b (h w) c") # b n c
out = self.mhsa_path_dropout(self.e_mhsa(out))
x = x + rearrange(out, "b (h w) c -> b c h w", h=H)
out = self.projection(x)
out = out + self.mhca_path_dropout(self.mhca(out))
x = torch.cat([x, out], dim=1)
if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
out = self.norm2(x)
else:
out = x
x = x + self.mlp_path_dropout(self.mlp(out))
return x
class NextViT(nn.Module):
def __init__(self, stem_chs, depths, path_dropout, attn_drop=0, drop=0, num_classes=1000,
strides=[1, 2, 2, 2], sr_ratios=[8, 4, 2, 1], head_dim=32, mix_block_ratio=0.75,
use_checkpoint=False, resume='', with_extra_norm=True, frozen_stages=-1,
norm_eval=False, norm_cfg=None,
):
super(NextViT, self).__init__()
self.use_checkpoint = use_checkpoint
self.frozen_stages = frozen_stages
self.with_extra_norm = with_extra_norm
self.norm_eval = norm_eval
self.stage_out_channels = [[96] * (depths[0]),
[192] * (depths[1] - 1) + [256],
[384, 384, 384, 384, 512] * (depths[2] // 5),
[768] * (depths[3] - 1) + [1024]]
# Next Hybrid Strategy
self.stage_block_types = [[NCB] * depths[0],
[NCB] * (depths[1] - 1) + [NTB],
[NCB, NCB, NCB, NCB, NTB] * (depths[2] // 5),
[NCB] * (depths[3] - 1) + [NTB]]
self.stem = nn.Sequential(
ConvBNReLU(3, stem_chs[0], kernel_size=3, stride=2),
ConvBNReLU(stem_chs[0], stem_chs[1], kernel_size=3, stride=1),
ConvBNReLU(stem_chs[1], stem_chs[2], kernel_size=3, stride=1),
ConvBNReLU(stem_chs[2], stem_chs[2], kernel_size=3, stride=2),
)
input_channel = stem_chs[-1]
features = []
idx = 0
dpr = [x.item() for x in torch.linspace(0, path_dropout, sum(depths))] # stochastic depth decay rule
for stage_id in range(len(depths)):
numrepeat = depths[stage_id]
output_channels = self.stage_out_channels[stage_id]
block_types = self.stage_block_types[stage_id]
for block_id in range(numrepeat):
if strides[stage_id] == 2 and block_id == 0:
stride = 2
else:
stride = 1
output_channel = output_channels[block_id]
block_type = block_types[block_id]
if block_type is NCB:
layer = NCB(input_channel, output_channel, stride=stride, path_dropout=dpr[idx + block_id],
drop=drop, head_dim=head_dim)
features.append(layer)
elif block_type is NTB:
layer = NTB(input_channel, output_channel, path_dropout=dpr[idx + block_id], stride=stride,
sr_ratio=sr_ratios[stage_id], head_dim=head_dim, mix_block_ratio=mix_block_ratio,
attn_drop=attn_drop, drop=drop)
features.append(layer)
input_channel = output_channel
idx += numrepeat
self.features = nn.Sequential(*features)
self.extra_norm_list = None
if with_extra_norm:
self.extra_norm_list = []
for stage_id in range(len(self.stage_out_channels)):
self.extra_norm_list.append(nn.BatchNorm2d(
self.stage_out_channels[stage_id][-1], eps=NORM_EPS))
self.extra_norm_list = nn.Sequential(*self.extra_norm_list)
self.norm = nn.BatchNorm2d(output_channel, eps=NORM_EPS)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.proj_head = nn.Sequential(
nn.Linear(output_channel, num_classes),
)
self.stage_out_idx = [sum(depths[:idx + 1]) - 1 for idx in range(len(depths))]
print('initialize_weights...')
self._initialize_weights()
if resume:
self.init_weights(resume)
if norm_cfg is not None:
self = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages > 0:
self.stem.eval()
for param in self.stem.parameters():
param.requires_grad = False
for idx, layer in enumerate(self.features):
if idx <= self.stage_out_idx[self.frozen_stages - 1]:
layer.eval()
for param in layer.parameters():
param.requires_grad = False
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
freezed."""
super(NextViT, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def merge_bn(self):
self.eval()
for idx, module in self.named_modules():
if isinstance(module, NCB) or isinstance(module, NTB):
module.merge_bn()
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
print('\n using pretrained model\n')
checkpoint = torch.load(pretrained, map_location='cpu')['model']
self.load_state_dict(checkpoint, strict=False)
def _initialize_weights(self):
for n, m in self.named_modules():
if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm, nn.BatchNorm1d)):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=.02)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
outputs = list()
x = self.stem(x)
stage_id = 0
for idx, layer in enumerate(self.features):
if self.use_checkpoint:
x = checkpoint.checkpoint(layer, x)
else:
x = layer(x)
if idx == self.stage_out_idx[stage_id]:
if self.with_extra_norm:
if stage_id < 3:
x = self.extra_norm_list[stage_id](x)
else:
x = self.norm(x)
outputs.append(x)
stage_id += 1
return outputs
class nextvit_small(NextViT):
def __init__(self, resume='', **kwargs):
super(nextvit_small, self).__init__(
stem_chs=[64, 32, 64], depths=[3, 4, 10, 3], path_dropout=0.2, resume=resume, **kwargs
)
class nextvit_base(NextViT):
def __init__(self, resume='', **kwargs):
super(nextvit_base, self).__init__(
stem_chs=[64, 32, 64], depths=[3, 4, 20, 3], path_dropout=0.2, resume=resume, **kwargs
)
class nextvit_large(NextViT):
def __init__(self, resume='', **kwargs):
super(nextvit_large, self).__init__(
stem_chs=[64, 32, 64], depths=[3, 4, 30, 3], path_dropout=0.2, resume=resume, **kwargs
)
Citation
@article{li2022next,
title={Next-ViT: Next Generation Vision Transformer for Efficient Deployment in Realistic Industrial Scenarios},
author={Li, Jiashi and Xia, Xin and Li, Wei and Li, Huixia and Wang, Xing and Xiao, Xuefeng and Wang, Rui and Zheng, Min and Pan, Xin},
journal={arXiv preprint arXiv:2207.05501},
year={2022}
}
Let's check that the backbone works on its own.
backbone = nextvit_small()
# forward
inputs = torch.ones((1, 3, 224, 224)) # B, C, H, W
features = backbone(inputs)
# outputs
for feature in features:
print(feature.shape)
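Given the stem (stride 4) and the stage strides of 1, 2, 2, 2 above, this should print four feature maps along the lines of (1, 96, 56, 56), (1, 256, 28, 28), (1, 512, 14, 14) and (1, 1024, 7, 7), which is exactly what we will declare as _out_channels next.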
Now let's hook it up to Segmentation Models Pytorch.
from typing import List, Dict
from segmentation_models_pytorch.encoders._base import EncoderMixin

class NextViTSmallEncoder(torch.nn.Module, EncoderMixin):
    def __init__(self, **kwargs):
        super().__init__()
        # number of channels for each encoder feature tensor, list of integers
        self._out_channels: List[int] = [96, 256, 512, 1024]
        self._depth: int = len(self._out_channels)
        self._in_channels: int = 3
        self.backbone = nextvit_small()

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        return self.backbone(x)
smp.encoders.encoders["nextvit_small"] = {
    "encoder": NextViTSmallEncoder,
    "pretrained_settings": {},
    "params": {},
}
Now we load it and check that it works.
model = smp.Unet(
encoder_name="nextvit_small",
encoder_weights=None,
encoder_depth=4,
decoder_channels=(512, 256, 96, 96)
)
# forward
inputs = torch.ones((1, 3, 224, 224)) # B, C, H, W
outputs = model(inputs)
# outputs
outputs.shape
It ran fine, but the output size is different from the input, isn't it?
224 -> 56, i.e. 1/4 of the original size.
If you are familiar with these architectures you will have guessed why: when the first block of layers reduces the resolution, or when the input is split into patches as in ViT, spatial information is folded into the channel dimension, so the output can be smaller than the original image.
In segmentation, PSPNet, the DeepLab family and others likewise produce outputs at 1/2 or 1/4 resolution.
In such cases, adjust the resolution by, for example:
- enlarging carefully in the decoder with upsampling
- scaling the small output back up in one go (a minimal sketch follows below)
Each of these choices gives the model its own quirks.
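As a concrete example of the second option, here is a minimal sketch that simply resizes the logits back to the input size with F.interpolate. The wrapper class and the choice of bilinear interpolation are my own, not something provided by the library:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UpsampledSegmenter(nn.Module):
    """Run the smp model, then resize the logits back to the input resolution."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        logits = self.model(x)  # e.g. (B, classes, 56, 56) for a 224x224 input here
        return F.interpolate(logits, size=x.shape[-2:], mode="bilinear", align_corners=False)

wrapped = UpsampledSegmenter(model)
outputs = wrapped(torch.ones((1, 3, 224, 224)))
outputs.shape  # back at the input resolution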
Features for improving accuracy
While we are at it, here are a few ideas for improving accuracy.
AuxiliaryLoss
This is one technique for stabilizing segmentation training and improving accuracy.
I believe it was first proposed in PSPNet (I may be misremembering).
You also see it in 1st place solutions on Kaggle.
Now for the implementation.
aux_params=dict(
pooling='avg', # one of 'avg', 'max'
dropout=0.5, # dropout ratio, default is None
activation='sigmoid', # activation function, default is None
classes=4, # define number of output labels
)
This configures the extra classification head that the model will output.
model = smp.Unet(
encoder_name="nextvit_small",
encoder_weights=None,
encoder_depth=4,
decoder_channels=(512, 256, 96, 96),
aux_params=aux_params)
mask, label = model(inputs)
mask.shape, label.shape
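As a rough sketch of how this is typically used during training: the mask head above returns raw logits (no activation was set), the auxiliary head already applies a sigmoid, and the 0.4 weight is the value often quoted from PSPNet. The loss choices and the placeholder targets below are my own assumptions:
import torch
import torch.nn as nn

mask_loss_fn = nn.BCEWithLogitsLoss()  # the mask output is raw logits
aux_loss_fn = nn.BCELoss()             # the aux head already applied sigmoid via aux_params

mask_pred, label_pred = model(inputs)
mask_target = torch.zeros_like(mask_pred)    # placeholder ground-truth mask
label_target = torch.zeros_like(label_pred)  # placeholder image-level labels

loss = mask_loss_fn(mask_pred, mask_target) + 0.4 * aux_loss_fn(label_pred, label_target)
loss.backward()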
2.5D / 3D processing
By inserting 3D processing in the first layer, you can extend the model so that it also mixes information along the slice (depth) direction.
class MiddleSkipConnectionEncoder(torch.nn.Module, EncoderMixin):
    def __init__(self, **kwargs):
        super().__init__()
        self._out_channels = [1, 64, 96, 192, 384, 768]
        self._depth: int = 5
        self._in_channels: int = 1
        # 3D convolution that mixes the neighbouring slices before the 2D backbone
        self.cnn3d = nn.Conv3d(3, 32, kernel_size=3, padding=1)
        self.backbone = timm.create_model('maxvit_small_tf_384', pretrained=False, features_only=True, in_chans=32)

    def forward(self, x: torch.Tensor):
        features_x = self.cnn3d(x.unsqueeze(2)).squeeze(2)
        feat2, feat3, feat4, feat5, feat6 = self.backbone(features_x)
        # keep the centre slice (1 channel) as the highest-resolution skip feature
        return [x[:, 1:2, :, :], feat2, feat3, feat4, feat5, feat6]
smp.encoders.encoders["maxvit384_encoder"] = {
    "encoder": MiddleSkipConnectionEncoder,
    "params": {},
    "pretrained_settings": {},
}
This is less about the encoder itself and more about customizing what happens before (and after) the encoder.
self.cnn3d(x.unsqueeze(2)).squeeze(2)
This is the part where the neighbouring slices are mixed together with a kernel size of 3.
model = smp.Unet(encoder_name='maxvit384_encoder', encoder_weights=None, classes=1)
# forward
inputs = torch.ones((1, 3, 384, 384)) # B, C, H, W
outputs = model(inputs)
# outputs
outputs.shape
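For context, the 3-channel input in this 2.5D setup is usually a slice stacked together with its two neighbours. A minimal sketch with a made-up volume (the array and the slice index are purely illustrative):
import numpy as np
import torch

volume = np.random.rand(64, 384, 384).astype(np.float32)    # hypothetical 3D volume (D, H, W)
i = 32                                                       # the slice we actually want to segment
stack = np.stack([volume[i - 1], volume[i], volume[i + 1]])  # previous / current / next slice
inputs = torch.from_numpy(stack).unsqueeze(0)                # (1, 3, 384, 384)
outputs = model(inputs)                                      # the maxvit384_encoder model above
outputs.shape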
Extra
In the 1st place solution of the Kaggle SenNet competition, the decoder of Segmentation Models Pytorch is customized in addition to the encoder.
The model is loaded as follows.
def build_model(backbone, in_channels, num_classes):
    model = smp.Unet(
        encoder_name=backbone,
        encoder_weights=None,
        encoder_args={"in_channels": in_channels},
        decoder_norm_type="GN",
        decoder_act_type="GeLU",
        decoder_upsample_method="nearest",
        in_channels=in_channels,
        classes=num_classes,
        activation=None,
    )
    return model
The winners made the following changes of their own:
- replace BatchNorm with GroupNorm
- replace ReLU with GELU
ConvNeXt uses GELU, and ConvNeXt V2 adds Global Response Normalization (GRN), so my guess is that this was meant to bring the decoder closer to that style of processing.
Concrete code
class UnetDecoder(nn.Module):
def __init__(
self,
encoder_channels,
decoder_channels,
n_blocks=5,
norm_type=None,
act_type="ReLU",
attention_type=None,
center=False,
use_checkpoint=False,
scale_factor=2,
upsample_method="nearest",
):
super().__init__()
if n_blocks != len(decoder_channels):
raise ValueError(
"Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
n_blocks, len(decoder_channels)
)
)
self.use_checkpoint = use_checkpoint
# remove first skip with same spatial resolution
encoder_channels = encoder_channels[1:]
# reverse channels to start from head of encoder
encoder_channels = encoder_channels[::-1]
# computing blocks input and output channels
head_channels = encoder_channels[0]
in_channels = [head_channels] + list(decoder_channels[:-1])
skip_channels = list(encoder_channels[1:]) + [0]
out_channels = decoder_channels
if center:
self.center = CenterBlock(head_channels, head_channels, norm_type=norm_type, act_type=act_type)
else:
self.center = nn.Identity()
if not isinstance(scale_factor, (tuple, list)):
scale_factor = [scale_factor] * len(out_channels)
assert len(scale_factor) == len(out_channels)
# combine decoder keyword arguments
kwargs = dict(norm_type=norm_type, act_type=act_type, attention_type=attention_type, upsample_method=upsample_method)
blocks = [
DecoderBlock(in_ch, skip_ch, out_ch, s, **kwargs)
for in_ch, skip_ch, out_ch, s in zip(in_channels, skip_channels, out_channels, scale_factor)
]
self.blocks = nn.ModuleList(blocks)
def forward(self, *features):
features = features[1:] # remove first skip with same spatial resolution
features = features[::-1] # reverse channels to start from head of encoder
head = features[0]
skips = features[1:]
x = self.center(head)
for i, decoder_block in enumerate(self.blocks):
skip = skips[i] if i < len(skips) else None
if self.use_checkpoint:
x = checkpoint.checkpoint(decoder_block, x, skip)
else:
x = decoder_block(x, skip)
return x
Here, new arguments are added to the decoder.
On top of that, arguments are added to the decoder block as well.
from segmentation_models_pytorch.base import modules as md
class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
skip_channels,
out_channels,
scale_factor=2,
norm_type=None,
act_type="ReLU",
attention_type=None,
upsample_method="nearest",
):
super().__init__()
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
norm_type=norm_type,
act_type=act_type,
)
self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
norm_type=norm_type,
act_type=act_type,
)
self.attention2 = md.Attention(attention_type, in_channels=out_channels)
self.scale_factor = scale_factor
self.upsample_method = upsample_method
if upsample_method == "transposed_conv":
self.upsample = nn.Sequential(
nn.ConvTranspose2d(in_channels, in_channels, kernel_size=scale_factor, stride=scale_factor),
md.LayerNorm2d(in_channels),
nn.GELU()
)
def forward(self, x, skip=None):
if self.upsample_method == "transposed_conv":
x = self.upsample(x)
else:
x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.upsample_method)
# x = F.interpolate(x, scale_factor=self.scale_factor, mode="bilinear")
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.attention1(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.attention2(x)
return x
The part below is where the layers that are actually used get selected based on those arguments.
Concrete code
class Conv2dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
act_type="ReLU",
norm_type=None
):
if norm_type is None:
norm_type = ""
assert norm_type in ["BN", "LN", "GN", "IN", "InPlaceABN", ""]
if norm_type == "inplace" and InPlaceABN is None:
raise RuntimeError(
"In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
+ "To install see: https://github.com/mapillary/inplace_abn"
)
conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=(norm_type==""),
)
act_type = act_type.lower() if act_type is not None else ""
if act_type == "relu":
act = nn.ReLU(inplace=True)
elif act_type == "gelu":
act = nn.GELU()
elif act_type == "selu":
act = nn.SELU(inplace=True)
elif act_type == "silu":
act = nn.SiLU(inplace=True)
else:
act = nn.Identity()
if norm_type == "InPlaceABN":
norm = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
relu = nn.Identity()
elif norm_type == "BN":
norm = nn.BatchNorm2d(out_channels)
elif norm_type == "LN":
norm = LayerNorm2d(out_channels)
elif norm_type == "GN":
norm = nn.GroupNorm(num_groups=min(32, out_channels), num_channels=out_channels)
elif norm_type == "IN":
norm = nn.InstanceNorm2d(out_channels)
else:
norm = nn.Identity()
super(Conv2dReLU, self).__init__(conv, norm, act)
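A quick check of the modified block with the GroupNorm + GELU combination discussed above (the shapes are arbitrary):
block = Conv2dReLU(64, 128, kernel_size=3, padding=1, norm_type="GN", act_type="GeLU")
x = torch.randn(1, 64, 32, 32)
block(x).shape  # torch.Size([1, 128, 32, 32])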
When you end up needing to customize the decoder in addition to the encoder, winners' code like this is a great thing to study.
The top SenNet solutions are written up below.
Closing
Thank you for reading all the way to the end.
Segmentation tasks are used heavily in fields such as satellite imagery and medical imaging, and these days models and implementations are increasingly expected to handle time series and 3D data well.
Even when the model side has to absorb that complexity, libraries such as Segmentation Models PyTorch on top of PyTorch let us keep the code simple.
Extra
Besides this article, I also write other posts and competition solution write-ups, which I hope you find useful.
I also write about satellite data analysis as a writer for 宙畑 (Sorabatake).
I mostly work on SAR analysis, and I am interested in image AI, geospatial and satellite data, and point clouds.
I enjoy connecting with people who are studying these topics, so feel free to reach out.