🗿

Semantic Segmentation Model "SegNeXt" 実装と解説

2022/12/08に公開約19,700字

概要

NeurIPS 2022で採択された"SegNeXt"はCNNベースの手法で高い精度と計算効率を示した。
近年のエンコーダデコーダ系のセグメンテーションモデルはTransformerのAttentionをうまい具合にしてSoTAというものが多かったが、Segnextではこれらのモデルで用いられた知見をCNNに持ち込むことで、従来のAttentionの重い計算量を置き換える新しいモデルと、セグメンテーションにおいて重要な知見をもたらした。
Segnextは純粋なCNNのエンコーダと、マルチスケールコンテキストを扱えるデコーダを用いて、効率的なセグメンテーションタスクを行う事ができる。

[1]より、SegNeXtと従来手法の比較

従来のセグメンテーションの知見

有名なモデルではCNNのDeepLabやTransformer系のSegFormerなどがあるが、これらの成功には以下の特性があると考えられる。

  • Transformer系のモデルはエンコーダが強力で、強いBackbornネットワークの寄与が大きいこと
  • 画像内の様々な大きさのオブジェクトに対して密な予測タスクを解くためには複数解像度の特徴をうまく使うこと
  • 空間(縦横)方向のAttentionでセグメンテーションを行う領域内の重要度を扱うために重み付けを行うこと
  • 高解像度画像を扱うため計算量は少ないほうが良いこと

これらを考慮して、従来のCNNエンコーダデコーダモデルを再設計したSegnextを作る。

Segnextはエンコーダの各ブロックでVANの方法に則って単純な乗算でAttentionを行い、デコーダではマルチスケール特徴量を使ってローカル/グローバル両方のコンテキストを得られるHamburgerの方法を使うことで低レベルから高レベルまでの情報を集める。
これによりエンコーダは純粋な畳込みベースのモデルとなり、Transformer系のものと比べ計算効率が高くなったうえ、オブジェクトのディテール処理についてはTransformerのそれよりも良い性能を示すようになった。

一般的なセグメンテーションモデルのエンコーダは分類モデル(例えばResNetやConvNeXtなど)であるが、セグメンテーションという密な予測タスクのためには一般画像分類タスクでの改善は少ないかもしれない。従ってSegnextではエンコーダとデコーダを協調するように作りなおした。

SegNeXtの内容

Segnextの貢献は次の部分にある。

  • Convk×kをConvk×1とConv1×kの組に分解する[4]の成果に、マルチスケールの受容野とAtentionを加えたこと
  • GoogLeNetのようなマルチスケールブランチ構造から特徴を得るが、一般的なものと異なりエンコーダのみをマルチスケール特徴抽出に用いたこと
  • 空間方向のAttentionだけでなくChannel方向のAttentionも行うため、VANのラージカーネルAttentionを用いて、さらにこれをマルチスケール特徴抽出につなげたこと

[2]より、VANのLarge Kernel Attention (a)

実際の全体像はmmsegmentaionのconfigとして以下のように書かれている。

norm_cfg = dict(type='SyncBN', requires_grad=True)
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
find_unused_parameters = True

model = dict(  # Segnext Base Model ADE160k (segnext.base.512x512.ade.160k.py)
    type='EncoderDecoder',
    backbone=dict(
        embed_dims=[64, 128, 320, 512],
        depths=[3, 3, 12, 3],
        init_cfg=dict(type='Pretrained', checkpoint='pretrained/mscan_b.pth'),
        drop_path_rate=0.1),
    decode_head=dict(
        type='LightHamHead',
        in_channels=[128, 320, 512],
        in_index=[1, 2, 3],
        channels=512,
        ham_channels=512,
        dropout_ratio=0.1,
        num_classes=150,
        norm_cfg=ham_norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

data = dict(samples_per_gpu=4)
evaluation = dict(interval=8000, metric='mIoU')
checkpoint_config = dict(by_epoch=False, interval=8000)
# optimizer
optimizer = dict(
    _delete_=True, 
    type='AdamW', 
    lr=0.00006, 
    betas=(0.9, 0.999), 
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
        	'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
       }
   )
)

lr_config = dict(
    _delete_=True, 
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0, 
    min_lr=0.0, 
    by_epoch=False
)

Encoder

SegnextのエンコーダはMSCAN(Multi-Scale Convolutional Attention Network)と呼び、ローカル情報を扱うdepthwise Conv、マルチスケール特徴を扱うmulti-branch depthwise strip Conv、Channel間のコンテキストを扱うConv1×1の3つを用いる。

[1]より、SegNeXtの構成とAttention

mmsegmentation内の実装を簡潔にしたものを見ていく。
まず基本パーツについては次のものがある。

  • StemConv: 畳み込みで[B,3,HW]の画像テンソルを[B,C0,HW/4]にダウンスケールする
  • OverlapPatchEmbed: 通常のConv Patchfyはkernel_size == strideの畳み込みで行うが、この条件を緩めてPatchfy(つまりただのstride付き畳み込み)
  • Mlp: depthwise Convを加えたMLP、Conv1×1 → DWConv3×3 → GELU → Conv1×1 を行う(Conv1×1はLinearと等価であることに注目)
  • AttentionModule: カーネルサイズが7,11,21のdepthwise Convを縦横に分解したものを加算し、マルチスケール特徴に対してAtentionを行う(下図)
  • SpatialAttention: Conv1×1でAttentionModuleを挟んで呼び出す

[2]より、VANに従ったMlpやAttentionの構成

mscan.py
import torch
from torch import nn
from torch.nn import functional as F



class StemConv(nn.Module):
    def __init__(
        self, 
        in_channels, 
        out_channels, 
    ):
        super(StemConv, self).__init__()

        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, out_channels//2, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels//2),
            nn.GELU(),
            nn.Conv2d(out_channels//2, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
        )
        return None


    def forward(self, x):
        x = self.proj(x)
        _, _, h, w = x.size()
        x = x.flatten(2).transpose(1, 2)
        return x, h, w



class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size=3, stride=2, in_chans=3, embed_dim=768):
        super().__init__()

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=patch_size//2)
        self.norm = nn.BatchNorm2d(embed_dim)
        return None


    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = self.norm(x)
        x = x.flatten(2).transpose(1, 2)
        return x, H, W



class Mlp(nn.Module):
    def __init__(
        self, 
        in_features, 
        hidden_features=None, 
        out_features=None, 
        drop=0.
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.dwconv = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True, groups=hidden_features)  # depthwise conv
        self.act = nn.GELU()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        return None


    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x



class AttentionModule(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
        self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
        self.conv3 = nn.Conv2d(dim, dim, 1)
        return None


    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)
        attn_0 = self.conv0_1(attn)
        attn_0 = self.conv0_2(attn_0)
        attn_1 = self.conv1_1(attn)
        attn_1 = self.conv1_2(attn_1)
        attn_2 = self.conv2_1(attn)
        attn_2 = self.conv2_2(attn_2)
        attn = attn + attn_0 + attn_1 + attn_2
        attn = self.conv3(attn)
        return attn * u



class SpatialAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = AttentionModule(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)
        return None


    def forward(self, x):
        shorcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shorcut
        return x

これらにより構成されるBlockと全体の構成は以下である。

mscan.py
class Block(nn.Module):
    def __init__(self,
        dim,
        mlp_ratio=4,
    ):
        super().__init__()
        # in training, dropout_ratio = 0.1
        drop=0.
        drop_path=0.

        self.norm1 = nn.BatchNorm2d(dim)
        self.attn = SpatialAttention(dim)
        self.drop_path = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
        layer_scale_init_value = 1e-2
        self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        return None


    def forward(self, x, h, w):
        b, n, c = x.shape
        x = x.permute(0, 2, 1).view(b, c, h, w)
        x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
        x = x.view(b, c, n).permute(0, 2, 1)
        return x



class MSCAN(nn.Module):
    def __init__(self,
        embed_dims=[64, 128, 256, 512],
        depths=[3, 3, 9, 3],
    ):
        super().__init__()
        # SegNeXt Encoder
        #       SegNeXt-T           SegNeXt-S           SegNeXt-B           SegNeXt-L
        # HW/4  C=32,  Block=3      C=64, Block=2       C=64, Block=3       C=64, Block=3
        # HW/8  C=64,  Block=3      C=128, Block=2      C=128, Block=3      C=128, Block=5
        # HW/16 C=160, Block=5      C=320, Block=4      C=320, Block=12     C=320, Block=27
        # HW/32 C=256, Block=2      C=512, Block=2      C=512, Block=3      C=512, Block=3

        self.depths = depths
        self.num_stages = 4

        self.patch_embed0 = StemConv(3, embed_dims[0])
        self.block0 = nn.ModuleList([Block(dim=embed_dims[0]) for _ in range(depths[0])])
        self.norm0 = nn.LayerNorm(embed_dims[0])

        self.patch_embed1 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.block1 = nn.ModuleList([Block(dim=embed_dims[1]) for _ in range(depths[1])])
        self.norm1 = nn.LayerNorm(embed_dims[1])
        
        self.patch_embed2 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
        self.block2 = nn.ModuleList([Block(dim=embed_dims[2]) for _ in range(depths[2])])
        self.norm2 = nn.LayerNorm(embed_dims[2])

        self.patch_embed3 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
        self.block3 = nn.ModuleList([Block(dim=embed_dims[3]) for _ in range(depths[3])])
        self.norm3 = nn.LayerNorm(embed_dims[3])
        return None


    def forward(self, x):
        b = x.shape[0]

        x, h, w = self.patch_embed0(x)
        for block in self.block0:
            x = block(x, h, w)
        x = self.norm0(x)
        x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        p0 = x

        x, h, w = self.patch_embed1(x)
        for block in self.block1:
            x = block(x, h, w)
        x = self.norm1(x)
        x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        p1 = x

        x, h, w = self.patch_embed2(x)
        for block in self.block2:
            x = block(x, h, w)
        x = self.norm2(x)
        x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        p2 = x

        x, h, w = self.patch_embed3(x)
        for block in self.block3:
            x = block(x, h, w)
        x = self.norm3(x)
        x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        p3 = x

        return p0, p1, p2, p3

以上によりエンコーダが作られる。
このBlock数やCahnnel数をスケールすることでtiny, small, base, largeモデルが作られる。

[1]より、各スケールの構成

Decoder

セグメンテーションモデルのデコーダは、SegFormerのようなMLP、CNN系のASPPやPSPなどがあるが、Segnextでは複数解像度の特徴マップを結合してHeadの処理を行う。(と書いてあるが、近年のCNN系のものも複数解像度を結合しておこなうFPN的なものを持っているので、マクロデザインに関してはどこが新しいかはわからない)
SegnextのデコーダはLightHamburgerを用いる。LightHamburgerは従来のCNN系のデコーダのHeadより軽量だが、Segnextのエンコーダが強力なため性能と計算効率のトレードオフが良い。
SegFormerのデコーダはstage1~Stage4までの解像度をすべて用いるが、SegnextではStage2~Stage4の3つの解像度を使う。Stage1の特徴マップは低レベルの情報が多すぎて性能が低下し、また計算量もネックになるため使わない。

[1]より、デコーダの比較

LightHamburgerはAttentionを行列分解(MD)によって行うもので、Conv1×1のLinear変換2つで非負行列因子分解(NMF)を挟む構成になっている。低ランク回復問題としてモデル化したMDを解くことで、入力特徴マップを分解し低ランク埋め込みを再構成するHamburgerは、現在のSelf AttentionによるQとKの積のように中間変数として大きな行列を生成する必要がないため時間/空間計算量が軽く、セマンティックセグメンテーションや画像生成など、グローバルコンテキストを学習することが重要なタスクで良い精度/計算量のトレードオフが示されている。

[3]より、Hamburgerモジュールの構成

[3]より、NMFのアルゴリズム

mmsegmentationでの実装を見ていく。mmcvに依存しないように簡潔に改変している。またオリジナルのコードはHamburger本家によるものである。

  • NMF2D: 非負行列因子分解によりテンソルの行列分解でAttentionを計算する。公式実装では_MatrixDecomposition2DBaseを継承する形で書かれている
  • Hamburger: Conv1×1とNMF2Dを呼びHumburgerモジュールを構築する
  • LightHamHead: 3つの特徴マップを入力しセグメンテーションマスクを出力する。引数のChannel数はmmsegmentationのconfigに準ずる
ham_head.py
class NMF2D(nn.Module):
    def __init__(self):
        super().__init__()
        self.spatial = True
        self.S = 1
        self.D = 512
        self.R = 64
        self.train_steps = 6
        self.eval_steps = 7
        self.inv_t = 1  # default 100
        self.eta = 0.9
        self.rand_init = True

        # Use for device verification such as `self.tensordevice.device == torch.device('cpu')`
        self.tensordevice = nn.Parameter(torch.empty(0))
        return None


    def _build_bases(self, B, S, D, R):
        # _MatrixDecomposition2DBase default empty
        if self.tensordevice.device == torch.device('cuda'):
            bases = torch.rand((B * S, D, R)).cuda()
        else:
            bases = torch.rand((B * S, D, R))
        bases = F.normalize(bases, dim=1)
        return bases


    # @torch.no_grad()
    def local_step(self, x, bases, coef):  
        # _MatrixDecomposition2DBase default empty

        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative Update
        coef = coef * numerator / (denominator + 1e-6)

        # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
        numerator = torch.bmm(x, coef)
        # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
        denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
        # Multiplicative Update
        bases = bases * numerator / (denominator + 1e-6)

        return bases, coef


    # @torch.no_grad()
    def local_inference(self, x, bases):
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        coef = torch.bmm(x.transpose(1, 2), bases)
        coef = F.softmax(self.inv_t * coef, dim=-1)

        steps = self.train_steps if self.training else self.eval_steps
        for _ in range(steps):
            bases, coef = self.local_step(x, bases, coef)

        return bases, coef


    def compute_coef(self, x, bases, coef):  
        # _MatrixDecomposition2DBase default empty

        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # multiplication update
        coef = coef * numerator / (denominator + 1e-6)

        return coef


    def forward(self, x, return_bases=False):
        B, C, H, W = x.shape

        # (B, C, H, W) -> (B * S, D, N)
        if self.spatial:
            D = C // self.S
            N = H * W
            x = x.view(B * self.S, D, N)
        else:
            D = H * W
            N = C // self.S
            x = x.view(B * self.S, N, D).transpose(1, 2)

        if not self.rand_init and not hasattr(self, 'bases'):
            bases = self._build_bases(1, self.S, D, self.R)
            self.register_buffer('bases', bases)

        # (S, D, R) -> (B * S, D, R)
        if self.rand_init:
            bases = self._build_bases(B, self.S, D, self.R)
        else:
            bases = self.bases.repeat(B, 1, 1)

        bases, coef = self.local_inference(x, bases)

        # (B * S, N, R)
        coef = self.compute_coef(x, bases, coef)

        # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
        x = torch.bmm(bases, coef.transpose(1, 2))

        # (B * S, D, N) -> (B, C, H, W)
        if self.spatial:
            x = x.view(B, C, H, W)
        else:
            x = x.transpose(1, 2).view(B, C, H, W)

        # (B * H, D, R) -> (B, H, N, D)
        bases = bases.view(B, self.S, D, self.R)

        return x



class Hamburger(nn.Module):
    def __init__(self, ham_channels):
        super().__init__()

        self.hamburger = nn.Sequential(
            nn.Conv2d(ham_channels, ham_channels, 1),
            nn.ReLU(inplace=True),
            NMF2D(),
            nn.Conv2d(ham_channels, ham_channels, 1)
        )
        self.relu = nn.ReLU(inplace=True)
        return None


    def forward(self, x):
        ham = self.hamburger(x)
        ham = self.relu(x + ham)
        return ham



class LightHamHead(nn.Module):
    def __init__(self, in_channels=[128, 320, 512], ham_channels=512, channels=512, num_classes=150):
        super().__init__()
        input_chs = sum(in_channels)
        output_chs = channels

        self.upsample_x2 = nn.Upsample(scale_factor=2, mode='bicubic')
        self.upsample_x4 = nn.Upsample(scale_factor=4, mode='bicubic')
        self.squeeze = nn.Conv2d(input_chs, ham_channels, 1)
        self.hamburger = Hamburger(ham_channels)
        self.align = nn.Conv2d(ham_channels, output_chs, 1)
        self.cls_seg = nn.Conv2d(output_chs, num_classes, 1)
        return None


    def forward(self, p1, p2, p3):
        inputs = torch.cat([p1, self.upsample_x2(p2), self.upsample_x4(p3)], dim=1)
        x = self.squeeze(inputs)
        x = self.hamburger(x)
        output = self.align(x)
        output = self.cls_seg(output)
        return output

以上に示したエンコーダデコーダを繋げばSegnextのモデルが作られる。

実験結果

論文の表の通り。
Mask2Formerなどは詳細な比較も欲しい所。

[1]より、SegFormerとの比較

[1]より、ADE20K, Cityscapes, COCO-Stuffなどによるベンチマーク

感想

Transformer系でなくてもCNNでも高速&高精度なセグメンテーションモデルをつくることができるという主旨で、実際従来手法より高解像度なinputにおいてはかなり高速に見える。また最近耳にするVANやHamburgなどの手法が実際に応用されており、これからの流れを掴むにも良い論文だった。
しかしデコーダの実装を見てみると少し複雑に感じるところもある。これをMLPや純粋なConcat → Convによる処理で(SoTAでなくともいいので)もっとシビアに高速化できれば実用上も十分に使えるものになる気がする。
実装コードはApache-2.0 licenseで公開されているため、mmsegmentationに収録されることが期待される。

引用

[1] Meng-Hao Guo, Cheng-Ze Lu, Qibin Hou, Zhengning Liu, Ming-Ming Cheng, Shi-Min Hu, "SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation", https://arxiv.org/abs/2209.08575 NeurIPS 2022, 18 Sep 2022

[2] Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu, "Visual Attention Network", https://arxiv.org/abs/2202.09741 11 Jul 2022

[3] Zhengyang Geng, Meng-Hao Guo, Hongxu Chen, Xia Li, Ke Wei, Zhouchen Lin, "Is Attention Better Than Matrix Decomposition?", https://arxiv.org/abs/2109.04553 ICLR 2021, 28 Dec 2021

[4] Chao Peng, Xiangyu Zhang, Gang Yu, Guiming Luo, Jian Sun, "Large Kernel Matters -- Improve Semantic Segmentation by Global Convolutional Network", https://arxiv.org/abs/1703.02719 IEEE Conf. Comput. Vis. Pattern Recog. pp. 4353–4361, 8 Mar 2017

Discussion

ログインするとコメントできます