
How to Customize the Backbone in Segmentation Models Pytorch

Published on 2024/02/07

How to change the backbone of Segmentation Models Pytorch to whatever you want, and points to watch out for.

Introduction

https://pytorch.org/

Segmentation Models Pytorch is a library that provides easy-to-use models for segmentation tasks on top of PyTorch, the automatic-differentiation framework for deep learning.

It is convenient enough to serve everyone from beginners building their first models to Kaggle and other competitions, and even production work.

https://smp.readthedocs.io/en/latest/

There is already a rich selection of encoders to choose from as backbones.
However, I could not find an easy way to customize it when I wanted to use the latest models or do things like 3D customization, so I am writing down some examples here as a memo for myself.

Implementation

The implementation below is in Python.

Review of the basics

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)

As shown in the documentation, you can use it as above.

Built-in encoders

If you want to switch to one of the encoders that are already provided, you only need to change encoder_name according to the documentation below.

https://smp.readthedocs.io/en/latest/encoders.html

If you want to use the Mix Vision Transformer (MiT) encoder, it looks like this:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="mit_b4",     
    encoder_weights="imagenet",     
    classes=3,
)

Timm encoders

In addition, you can also use models from PyTorch Image Models, commonly known as timm, which provides a large collection of PyTorch models.

https://github.com/huggingface/pytorch-image-models

With this, you should be able to use most of the models you will ever need.

You can choose a model from the documentation below; there are more than 500 of them.
https://smp.readthedocs.io/en/latest/encoders_timm.html

For example, suppose you want to use MaxViT from the timm encoders.

https://arxiv.org/abs/2204.01697

Timm models can be loaded by setting encoder_name to "tu-" plus the timm model name.
For maxvit_small_tf_224, specify tu-maxvit_small_tf_224.

model = smp.Unet(
    encoder_name="tu-maxvit_small_tf_224",     
    encoder_weights="imagenet",     
    classes=3,
)

Custom models

Now for the main topic: how to customize the backbone, with examples.

The basic approach

To add your own encoder to the Segmentation Models PyTorch library, you only need to implement the following two things:

  1. A class that inherits from EncoderMixin
  2. Register the class under a name in smp.encoders.encoders

For example, when you want to plug in maxvit:

As a small tip, timm's list_models function lets you search the implemented model names with partial (wildcard) matching.

import timm
timm.list_models('maxvit*')
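# returns a list of matching names, e.g. ['maxvit_small_tf_224', 'maxvit_small_tf_384', ...] (depends on the timm version)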

First, as a simple example, we wrap timm's maxvit_small_tf_224 model.

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders._base import EncoderMixin
import timm

class MiddleSkipConnectionEncoder(nn.Module, EncoderMixin):
    
    def __init__(self, **kwargs):
        super().__init__()
        # number of channels of each feature tensor returned by forward()
        self._out_channels = [1, 64, 96, 192, 384, 768]
        self._depth: int = 5
        self._in_channels: int = 1

        self.backbone = timm.create_model('maxvit_small_tf_224', 
                                          pretrained=False, 
                                          features_only=True, 
                                          in_chans=self._in_channels)

    def forward(self, x: torch.Tensor):
        features = self.backbone(x)
        # prepend the input itself as the highest-resolution "feature"
        return [x[:,0,:,:], *features]

That covers step 1, inheriting the class.

smp.encoders.encoders["maxvit224_encoder"] = {
    "encoder": MiddleSkipConnectionEncoder,
    'params' : {},
    'pretrained_settings': {},
}

And that is step 2, registering the class under a name.

Now let's check that it works.

model = smp.Unet(encoder_name='maxvit224_encoder', encoder_weights=None ,classes=1)

# forward
inputs = torch.ones((1, 1, 224, 224)) # B, C, H, W
outputs = model(inputs)

# outputs
outputs.shape
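# expected: torch.Size([1, 1, 224, 224]) (the default Unet decoder upsamples back to the 224x224 input)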

Just for reference, here is the tu-maxvit_small_tf_224 version again for comparison.

model = smp.Unet(
    encoder_name="tu-maxvit_small_tf_224",     
    encoder_weights="imagenet",     
    classes=3,
)

Your own model

Now, what should you do when you want to bring in NextViT?

https://arxiv.org/abs/2207.05501

We use the official code below as a reference.
https://github.com/bytedance/Next-ViT/blob/main/segmentation/nextvit.py

Implementing the encoder

It is mostly copied as is, but some parts have been consolidated so everything fits in one place.

# Copyright (c) ByteDance Inc. All rights reserved.
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from torch.nn.modules.batchnorm import _BatchNorm
from timm.models.layers import DropPath, trunc_normal_

NORM_EPS = 1e-5

def merge_pre_bn(module, pre_bn_1, pre_bn_2=None):
    """ Merge pre BN to reduce inference runtime.
    """
    weight = module.weight.data
    if module.bias is None:
        zeros = torch.zeros(module.out_channels, device=weight.device).type(weight.type())
        module.bias = nn.Parameter(zeros)
    bias = module.bias.data
    if pre_bn_2 is None:
        assert pre_bn_1.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
        assert pre_bn_1.affine is True, "Unsupport bn_module.affine is False"

        scale_invstd = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
        extra_weight = scale_invstd * pre_bn_1.weight
        extra_bias = pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd
    else:
        assert pre_bn_1.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
        assert pre_bn_1.affine is True, "Unsupport bn_module.affine is False"

        assert pre_bn_2.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
        assert pre_bn_2.affine is True, "Unsupport bn_module.affine is False"

        scale_invstd_1 = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
        scale_invstd_2 = pre_bn_2.running_var.add(pre_bn_2.eps).pow(-0.5)

        extra_weight = scale_invstd_1 * pre_bn_1.weight * scale_invstd_2 * pre_bn_2.weight
        extra_bias = scale_invstd_2 * pre_bn_2.weight *(pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd_1 - pre_bn_2.running_mean) + pre_bn_2.bias

    if isinstance(module, nn.Linear):
        extra_bias = weight @ extra_bias
        weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
    elif isinstance(module, nn.Conv2d):
        assert weight.shape[2] == 1 and weight.shape[3] == 1
        weight = weight.reshape(weight.shape[0], weight.shape[1])
        extra_bias = weight @ extra_bias
        weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
        weight = weight.reshape(weight.shape[0], weight.shape[1], 1, 1)
    bias.add_(extra_bias)

    module.weight.data = weight
    module.bias.data = bias


class ConvBNReLU(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            groups=1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                              padding=1, groups=groups, bias=False)
        self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class PatchEmbed(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1):
        super(PatchEmbed, self).__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        if stride == 2:
            self.avgpool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        elif in_channels != out_channels:
            self.avgpool = nn.Identity()
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        else:
            self.avgpool = nn.Identity()
            self.conv = nn.Identity()
            self.norm = nn.Identity()

    def forward(self, x):
        return self.norm(self.conv(self.avgpool(x)))


class MHCA(nn.Module):
    """
    Multi-Head Convolutional Attention
    """
    def __init__(self, out_channels, head_dim):
        super(MHCA, self).__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        self.group_conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
                                       padding=1, groups=out_channels // head_dim, bias=False)
        self.norm = norm_layer(out_channels)
        self.act = nn.ReLU(inplace=True)
        self.projection = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.group_conv3x3(x)
        out = self.norm(out)
        out = self.act(out)
        out = self.projection(out)
        return out


class Mlp(nn.Module):
    def __init__(self, in_features, out_features=None, mlp_ratio=None, drop=0., bias=True):
        super().__init__()
        out_features = out_features or in_features
        hidden_dim = _make_divisible(in_features * mlp_ratio, 32)
        self.conv1 = nn.Conv2d(in_features, hidden_dim, kernel_size=1, bias=bias)
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(hidden_dim, out_features, kernel_size=1, bias=bias)
        self.drop = nn.Dropout(drop)

    def merge_bn(self, pre_norm):
        merge_pre_bn(self.conv1, pre_norm)

    def forward(self, x):
        x = self.conv1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.conv2(x)
        x = self.drop(x)
        return x


class NCB(nn.Module):
    """
    Next Convolution Block
    """
    def __init__(self, in_channels, out_channels, stride=1, path_dropout=0,
                 drop=0, head_dim=32, mlp_ratio=3):
        super(NCB, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        assert out_channels % head_dim == 0

        self.patch_embed = PatchEmbed(in_channels, out_channels, stride)
        self.mhca = MHCA(out_channels, head_dim)
        self.attention_path_dropout = DropPath(path_dropout)

        self.norm = norm_layer(out_channels)
        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop, bias=True)
        self.mlp_path_dropout = DropPath(path_dropout)
        self.is_bn_merged = False

    def merge_bn(self):
        if not self.is_bn_merged:
            self.mlp.merge_bn(self.norm)
            self.is_bn_merged = True

    def forward(self, x):
        x = self.patch_embed(x)
        x = x + self.attention_path_dropout(self.mhca(x))
        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm(x)
        else:
            out = x
        x = x + self.mlp_path_dropout(self.mlp(out))
        return x


class E_MHSA(nn.Module):
    """
    Efficient Multi-Head Self Attention
    """
    def __init__(self, dim, out_dim=None, head_dim=32, qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.proj = nn.Linear(self.dim, self.out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio ** 2
        if sr_ratio > 1:
            self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
        self.is_bn_merged = False

    def merge_bn(self, pre_bn):
        merge_pre_bn(self.q, pre_bn)
        if self.sr_ratio > 1:
            merge_pre_bn(self.k, pre_bn, self.norm)
            merge_pre_bn(self.v, pre_bn, self.norm)
        else:
            merge_pre_bn(self.k, pre_bn)
            merge_pre_bn(self.v, pre_bn)
        self.is_bn_merged = True

    def forward(self, x):
        B, N, C = x.shape
        q = self.q(x)
        q = q.reshape(B, N, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.transpose(1, 2)
            x_ = self.sr(x_)
            if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
                x_ = self.norm(x_)
            x_ = x_.transpose(1, 2)
            k = self.k(x_)
            k = k.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
            v = self.v(x_)
            v = v.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
        else:
            k = self.k(x)
            k = k.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
            v = self.v(x)
            v = v.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
        attn = (q @ k) * self.scale

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class NTB(nn.Module):
    """
    Next Transformer Block
    """
    def __init__(
            self, in_channels, out_channels, path_dropout, stride=1, sr_ratio=1,
            mlp_ratio=2, head_dim=32, mix_block_ratio=0.75, attn_drop=0, drop=0,
    ):
        super(NTB, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mix_block_ratio = mix_block_ratio
        norm_func = partial(nn.BatchNorm2d, eps=NORM_EPS)

        self.mhsa_out_channels = _make_divisible(int(out_channels * mix_block_ratio), 32)
        self.mhca_out_channels = out_channels - self.mhsa_out_channels

        self.patch_embed = PatchEmbed(in_channels, self.mhsa_out_channels, stride)
        self.norm1 = norm_func(self.mhsa_out_channels)
        self.e_mhsa = E_MHSA(self.mhsa_out_channels, head_dim=head_dim, sr_ratio=sr_ratio,
                             attn_drop=attn_drop, proj_drop=drop)
        self.mhsa_path_dropout = DropPath(path_dropout * mix_block_ratio)

        self.projection = PatchEmbed(self.mhsa_out_channels, self.mhca_out_channels, stride=1)
        self.mhca = MHCA(self.mhca_out_channels, head_dim=head_dim)
        self.mhca_path_dropout = DropPath(path_dropout * (1 - mix_block_ratio))

        self.norm2 = norm_func(out_channels)
        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop)
        self.mlp_path_dropout = DropPath(path_dropout)

        self.is_bn_merged = False

    def merge_bn(self):
        if not self.is_bn_merged:
            self.e_mhsa.merge_bn(self.norm1)
            self.mlp.merge_bn(self.norm2)
            self.is_bn_merged = True

    def forward(self, x):
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm1(x)
        else:
            out = x
        out = rearrange(out, "b c h w -> b (h w) c")  # b n c
        out = self.mhsa_path_dropout(self.e_mhsa(out))
        x = x + rearrange(out, "b (h w) c -> b c h w", h=H)

        out = self.projection(x)
        out = out + self.mhca_path_dropout(self.mhca(out))
        x = torch.cat([x, out], dim=1)

        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm2(x)
        else:
            out = x
        x = x + self.mlp_path_dropout(self.mlp(out))
        return x


class NextViT(nn.Module):
    def __init__(self, stem_chs, depths, path_dropout, attn_drop=0, drop=0, num_classes=1000,
                 strides=[1, 2, 2, 2], sr_ratios=[8, 4, 2, 1], head_dim=32, mix_block_ratio=0.75,
                 use_checkpoint=False, resume='', with_extra_norm=True, frozen_stages=-1,
                 norm_eval=False, norm_cfg=None,
                 ):
        super(NextViT, self).__init__()
        self.use_checkpoint = use_checkpoint
        self.frozen_stages = frozen_stages
        self.with_extra_norm = with_extra_norm
        self.norm_eval = norm_eval
        self.stage_out_channels = [[96] * (depths[0]),
                                   [192] * (depths[1] - 1) + [256],
                                   [384, 384, 384, 384, 512] * (depths[2] // 5),
                                   [768] * (depths[3] - 1) + [1024]]

        # Next Hybrid Strategy
        self.stage_block_types = [[NCB] * depths[0],
                                  [NCB] * (depths[1] - 1) + [NTB],
                                  [NCB, NCB, NCB, NCB, NTB] * (depths[2] // 5),
                                  [NCB] * (depths[3] - 1) + [NTB]]

        self.stem = nn.Sequential(
            ConvBNReLU(3, stem_chs[0], kernel_size=3, stride=2),
            ConvBNReLU(stem_chs[0], stem_chs[1], kernel_size=3, stride=1),
            ConvBNReLU(stem_chs[1], stem_chs[2], kernel_size=3, stride=1),
            ConvBNReLU(stem_chs[2], stem_chs[2], kernel_size=3, stride=2),
        )
        input_channel = stem_chs[-1]
        features = []
        idx = 0
        dpr = [x.item() for x in torch.linspace(0, path_dropout, sum(depths))]  # stochastic depth decay rule
        for stage_id in range(len(depths)):
            numrepeat = depths[stage_id]
            output_channels = self.stage_out_channels[stage_id]
            block_types = self.stage_block_types[stage_id]
            for block_id in range(numrepeat):
                if strides[stage_id] == 2 and block_id == 0:
                    stride = 2
                else:
                    stride = 1
                output_channel = output_channels[block_id]
                block_type = block_types[block_id]
                if block_type is NCB:
                    layer = NCB(input_channel, output_channel, stride=stride, path_dropout=dpr[idx + block_id],
                                drop=drop, head_dim=head_dim)
                    features.append(layer)
                elif block_type is NTB:
                    layer = NTB(input_channel, output_channel, path_dropout=dpr[idx + block_id], stride=stride,
                                sr_ratio=sr_ratios[stage_id], head_dim=head_dim, mix_block_ratio=mix_block_ratio,
                                attn_drop=attn_drop, drop=drop)
                    features.append(layer)
                input_channel = output_channel
            idx += numrepeat
        self.features = nn.Sequential(*features)

        self.extra_norm_list = None
        if with_extra_norm:
            self.extra_norm_list = []
            for stage_id in range(len(self.stage_out_channels)):
                self.extra_norm_list.append(nn.BatchNorm2d(
                    self.stage_out_channels[stage_id][-1], eps=NORM_EPS))
            self.extra_norm_list = nn.Sequential(*self.extra_norm_list)

        self.norm = nn.BatchNorm2d(output_channel, eps=NORM_EPS)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.proj_head = nn.Sequential(
            nn.Linear(output_channel, num_classes),
        )

        self.stage_out_idx = [sum(depths[:idx + 1]) - 1 for idx in range(len(depths))]
        print('initialize_weights...')
        self._initialize_weights()
        if resume:
            self.init_weights(resume)
        if norm_cfg is not None:
            self = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)
        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages > 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False
            for idx, layer in enumerate(self.features):
                if idx <= self.stage_out_idx[self.frozen_stages - 1]:
                    layer.eval()
                    for param in layer.parameters():
                        param.requires_grad = False

    def train(self, mode=True):
        """Convert the model into training mode while keep normalization layer
        freezed."""
        super(NextViT, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def merge_bn(self):
        self.eval()
        for idx, module in self.named_modules():
            if isinstance(module, NCB) or isinstance(module, NTB):
                module.merge_bn()

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            print('\n using pretrained model\n')
            checkpoint = torch.load(pretrained, map_location='cpu')['model']
            self.load_state_dict(checkpoint, strict=False)

    def _initialize_weights(self):
        for n, m in self.named_modules():
            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                trunc_normal_(m.weight, std=.02)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        outputs = list()
        x = self.stem(x)
        stage_id = 0
        for idx, layer in enumerate(self.features):
            if self.use_checkpoint:
                x = checkpoint.checkpoint(layer, x)
            else:
                x = layer(x)
            if idx == self.stage_out_idx[stage_id]:
                if self.with_extra_norm:
                    if stage_id < 3:
                        x = self.extra_norm_list[stage_id](x)
                    else:
                        x = self.norm(x)
                outputs.append(x)
                stage_id += 1
        return outputs

class nextvit_small(NextViT):
    def __init__(self, resume='', **kwargs):
        super(nextvit_small, self).__init__(
            stem_chs=[64, 32, 64], depths=[3, 4, 10, 3], path_dropout=0.2, resume=resume, **kwargs
        )

class nextvit_base(NextViT):
    def __init__(self, resume='', **kwargs):
        super(nextvit_base, self).__init__(
            stem_chs=[64, 32, 64], depths=[3, 4, 20, 3], path_dropout=0.2, resume=resume, **kwargs
        )
class nextvit_large(NextViT):
    def __init__(self, resume='', **kwargs):
        super(nextvit_large, self).__init__(
            stem_chs=[64, 32, 64], depths=[3, 4, 30, 3], path_dropout=0.2, resume=resume, **kwargs
        )

Citation

@article{li2022next,
  title={Next-ViT: Next Generation Vision Transformer for Efficient Deployment in Realistic Industrial Scenarios},
  author={Li, Jiashi and Xia, Xin and Li, Wei and Li, Huixia and Wang, Xing and Xiao, Xuefeng and Wang, Rui and Zheng, Min and Pan, Xin},
  journal={arXiv preprint arXiv:2207.05501},
  year={2022}
}

Let's check that the backbone works on its own.

backbone = nextvit_small()

# forward
inputs = torch.ones((1, 3, 224, 224)) # B, C, H, W
features = backbone(inputs)

# outputs
for feature in features:
    print(feature.shape)
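# expected for a 224x224 input (one feature map per NextViT stage):
# torch.Size([1, 96, 56, 56]), torch.Size([1, 256, 28, 28]),
# torch.Size([1, 512, 14, 14]), torch.Size([1, 1024, 7, 7])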

Now let's wire it into Segmentation Models Pytorch.

from typing import List, Dict
from segmentation_models_pytorch.encoders._base import EncoderMixin

class NextViTSmallEncoder(torch.nn.Module, EncoderMixin):
    
    def __init__(self, **kwargs):
        super().__init__()

        # A number of channels for each encoder feature tensor, list of integers
        self._out_channels: List[int] = [96, 256, 512, 1024]
        self._depth: int = len(self._out_channels)
        self._in_channels: int = 3
        self.backbone = nextvit_small()

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:        
        return self.backbone(x)

smp.encoders.encoders["nextvit_small"] = {
    "encoder": NextViTSmallEncoder,
    "pretrained_settings": {},
    "params": {},
}

Load it and check that it works.

model = smp.Unet(
    encoder_name="nextvit_small", 
    encoder_weights=None, 
    encoder_depth=4,
    decoder_channels=(512, 256, 96, 96)
)

# forward
inputs = torch.ones((1, 3, 224, 224)) # B, C, H, W
outputs = model(inputs)

# outputs
outputs.shape
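# torch.Size([1, 1, 56, 56])  (1/4 of the input resolution, as discussed below)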

It runs without problems, but the output size differs from the input, right?
It has gone from 224 to 56, i.e. 1/4 of the original.
If you know these architectures you can guess why: when the first block of layers downsamples, or when the image is split into patches and folded into the channel dimension as in ViT, the output can come out smaller than the original image size.

In segmentation, PSPNet and the DeepLab family also output at 1/2 or 1/4 resolution in the same way.

In this case, adjust the resolution, for example as in the sketch below.
These kinds of adjustments are where each model's quirks start to show.
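
As a minimal sketch, assuming bilinear interpolation is fine for the task, the output can simply be resized back to the input resolution:

import torch.nn.functional as F

# upsample the 1/4-resolution logits back to the 224x224 input size
outputs_full = F.interpolate(outputs, size=(224, 224), mode="bilinear", align_corners=False)
outputs_full.shape  # torch.Size([1, 1, 224, 224])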

Features for improving accuracy

Along the way, here are a few ideas for improving accuracy.

AuxiliaryLoss

This is one technique for stabilizing segmentation training and improving accuracy.
I believe it was first proposed in PSPNet, though I may be mistaken.

https://arxiv.org/abs/1612.01105

You also see it in 1st-place Kaggle solutions, for example:
https://www.kaggle.com/competitions/rsna-2023-abdominal-trauma-detection/discussion/447449

Now the implementation.

aux_params=dict(
    pooling='avg',             # one of 'avg', 'max'
    dropout=0.5,               # dropout ratio, default is None
    activation='sigmoid',      # activation function, default is None
    classes=4,                 # define number of output labels
)

This sets up the configuration for the classification head that will be added to the output.

model = smp.Unet(
    encoder_name="nextvit_small", 
    encoder_weights=None, 
    encoder_depth=4,
    decoder_channels=(512, 256, 96, 96),
    aux_params=aux_params)
mask, label = model(inputs)

mask.shape, label.shape
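# expected: (torch.Size([1, 1, 56, 56]), torch.Size([1, 4]))

As a rough sketch, assuming BCE-style losses and a 0.4 weight for the auxiliary term (both arbitrary choices here, not from the original solution), the classification loss is simply added to the segmentation loss during training:

import torch.nn as nn

seg_criterion = nn.BCEWithLogitsLoss()  # the mask output is raw logits (no activation was set on the model)
cls_criterion = nn.BCELoss()            # the label output already went through the 'sigmoid' set in aux_params

mask_target = torch.zeros_like(mask)    # dummy segmentation target with the same shape as the mask
label_target = torch.zeros_like(label)  # dummy image-level multi-label target

# add the weighted auxiliary classification loss to the main segmentation loss
loss = seg_criterion(mask, mask_target) + 0.4 * cls_criterion(label, label_target)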

2.5D / 3D processing

By inserting 3D processing in the first layer, you can extend the model so it also processes information along the slice (depth) direction.

class MiddleSkipConnectionEncoder(torch.nn.Module, EncoderMixin):
    
    def __init__(self, **kwargs):
        super().__init__()
        self._out_channels = [1, 64, 96, 192, 384, 768]
        self._depth: int = 5

        self._in_channels: int = 1

        self.cnn3d = nn.Conv3d(3, 32, kernel_size=3, padding=1)

        self.backbone = timm.create_model('maxvit_small_tf_384', pretrained=False, features_only=True, in_chans=32)

    def forward(self, x: torch.Tensor):
        # (B, 3, H, W) -> (B, 3, 1, H, W) -> Conv3d -> (B, 32, 1, H, W) -> (B, 32, H, W)
        features_x = self.cnn3d(x.unsqueeze(2)).squeeze(2)

        feat2, feat3, feat4, feat5, feat6 = self.backbone(features_x)

        return [x[:,1,:,:], feat2, feat3, feat4, feat5, feat6]

smp.encoders.encoders["maxvit384_encoder"] = {
    "encoder": MiddleSkipConnectionEncoder,
    'params' : {},
    'pretrained_settings': {},
}

This is less about the encoder itself and more about customizing what happens before and after the encoder.

The line self.cnn3d(x.unsqueeze(2)).squeeze(2) is where the kernel size of 3 weaves the neighboring slices into the computation.

model = smp.Unet(encoder_name='maxvit384_encoder', encoder_weights=None ,classes=1)

# forward
inputs = torch.ones((1, 3, 384, 384)) # B, C, H, W
outputs = model(inputs)

# outputs
outputs.shape
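# expected: torch.Size([1, 1, 384, 384]) (the decoder upsamples back to the 384x384 input)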

Bonus

In the 1st-place solution of the Kaggle SenNet competition, the Segmentation Models Pytorch decoder is customized along with the encoder.
https://www.kaggle.com/competitions/blood-vessel-segmentation/discussion/475522

The model is loaded like this:

def build_model(backbone, in_channels, num_classes):
    model = smp.Unet(
        encoder_name=backbone,
        encoder_weights=None,
        encoder_args={"in_channels": in_channels},
        decoder_norm_type="GN",
        decoder_act_type="GeLU",
        decoder_upsample_method="nearest",
        in_channels=in_channels,
        classes=num_classes,
        activation=None,
    )
    return model

Here the winner made the following changes of their own:

  • Replaced BatchNorm with GroupNorm
  • Replaced ReLU with GELU

ConvNeXt uses GELU, and ConvNeXt V2 adds Global Response Normalization (GRN), so I suppose this is meant to line the decoder up with similar kinds of operations.

The concrete code:
class UnetDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        n_blocks=5,
        norm_type=None,
        act_type="ReLU",
        attention_type=None,
        center=False,
        use_checkpoint=False,
        scale_factor=2,
        upsample_method="nearest",
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        self.use_checkpoint = use_checkpoint
        
        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

        if center:
            self.center = CenterBlock(head_channels, head_channels, norm_type=norm_type, act_type=act_type)
        else:
            self.center = nn.Identity()

        if not isinstance(scale_factor, (tuple, list)):
            scale_factor = [scale_factor] * len(out_channels)
        assert len(scale_factor) == len(out_channels)
        
        # combine decoder keyword arguments
        kwargs = dict(norm_type=norm_type, act_type=act_type, attention_type=attention_type, upsample_method=upsample_method)
        blocks = [
            DecoderBlock(in_ch, skip_ch, out_ch, s, **kwargs)
            for in_ch, skip_ch, out_ch, s in zip(in_channels, skip_channels, out_channels, scale_factor)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, *features):

        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        head = features[0]
        skips = features[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            if self.use_checkpoint:
                x = checkpoint.checkpoint(decoder_block, x, skip)
            else:
                x = decoder_block(x, skip)

        return x

Here, arguments are added to the decoder.

https://github.com/jing1tian/blood-vessel-segmentation/blob/31a6a493caecb067fff8a583dd568b5a3ba4dddf/segmentation_models_pytorch/decoders/unet/decoder.py#L94

In addition, arguments are added to the decoder block.

import torch
import torch.nn as nn
import torch.nn.functional as F

from segmentation_models_pytorch.base import modules as md


class DecoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        scale_factor=2,
        norm_type=None,
        act_type="ReLU",
        attention_type=None,
        upsample_method="nearest",
    ):
        super().__init__()
        self.conv1 = md.Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
        self.conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)
        self.scale_factor = scale_factor

        self.upsample_method = upsample_method
        if upsample_method == "transposed_conv":
            self.upsample = nn.Sequential(
                nn.ConvTranspose2d(in_channels, in_channels, kernel_size=scale_factor, stride=scale_factor),
                md.LayerNorm2d(in_channels),
                nn.GELU()
            )
        

    def forward(self, x, skip=None):
        if self.upsample_method == "transposed_conv":
            x = self.upsample(x)
        else:
            x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.upsample_method)
        # x = F.interpolate(x, scale_factor=self.scale_factor, mode="bilinear")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention2(x)
        return x

The following part is where the layer that actually gets applied is selected.
https://github.com/jing1tian/blood-vessel-segmentation/blob/31a6a493caecb067fff8a583dd568b5a3ba4dddf/segmentation_models_pytorch/base/modules.py#L10

The concrete code:
class Conv2dReLU(nn.Sequential):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        act_type="ReLU",
        norm_type=None
    ):
        
        if norm_type is None:
            norm_type = ""
            
        assert norm_type in ["BN", "LN", "GN", "IN", "InPlaceABN", ""]

        if norm_type == "inplace" and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
                + "To install see: https://github.com/mapillary/inplace_abn"
            )

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=(norm_type==""),
        )
        
        act_type = act_type.lower() if act_type is not None else ""
        if act_type == "relu":
            act = nn.ReLU(inplace=True)
        elif act_type == "gelu":
            act = nn.GELU()
        elif act_type == "selu":
            act = nn.SELU(inplace=True)
        elif act_type == "silu":
            act = nn.SiLU(inplace=True)
        else:
            act = nn.Identity()

        if norm_type == "InPlaceABN":
            norm = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            relu = nn.Identity()
        elif norm_type == "BN":
            norm = nn.BatchNorm2d(out_channels)
        elif norm_type == "LN":
            norm = LayerNorm2d(out_channels)
        elif norm_type == "GN":
            norm = nn.GroupNorm(num_groups=min(32, out_channels), num_channels=out_channels)
        elif norm_type == "IN":
            norm = nn.InstanceNorm2d(out_channels)
        else:
            norm = nn.Identity()

        super(Conv2dReLU, self).__init__(conv, norm, act)

When you need to customize the decoder in addition to the encoder, it is worth studying winners' code like this.

I have written about the top SenNet solutions here:
https://zenn.dev/syu_tan/articles/859a3fc56c714b

Closing

Thank you for reading to the end.
Segmentation tasks are used heavily in fields such as satellite imagery and medical imaging, and lately models and implementations are increasingly expected to handle time series and 3D data well too.
Even when that complexity is absorbed on the model side, libraries such as Segmentation Models PyTorch on top of PyTorch help keep the code simple, which is what I want to keep doing!

Bonus

I also write other articles and competition solution write-ups, so I hope some of them are useful to you.

https://zenn.dev/syu_tan

https://zenn.dev/syu_tan/articles/fbf0b40aa8c686

I also write for Sorabatake (宙畑) about satellite data analysis.

https://sorabatake.jp/?s=秀輔

I mostly work on SAR analysis, but I am also interested in image AI, geospatial and satellite data, and point cloud data.
I enjoy talking with people who are studying these topics, so feel free to reach out.

https://twitter.com/emmyeil
