Semantic Segmentation "SegNeXt" 実装と解説
Meng-Hao Guo, Cheng-Ze Lu, Qibin Hou, Zhengning Liu, Ming-Ming Cheng, Shi-Min Hu, "SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation", https://arxiv.org/abs/2209.08575 NeurIPS 2022, 18 Sep 2022
ConvNeXtのようにCNN系の手法でViTで発展した手法を倣おうというモチベーションで、Segmentationモデルを良い効率にしたという手法です。
ViT系はパラメータの効率を主張に含めがちですが実装時の計算時間でかなり不利になっていて使われにくいイメージなので、同じ精度でCNNを組めればとても良さそうです。
概要
NeurIPS 2022で採択された"SegNeXt"はCNNベースの手法で高い精度と計算効率を示した。
近年のエンコーダデコーダ系のセグメンテーションモデルはTransformerのAttentionをうまい具合にしてSoTAというものが多かったが、Segnextではこれらのモデルで用いられた知見をCNNに持ち込むことで、従来のAttentionの重い計算量を置き換える新しいモデルと、セグメンテーションにおいて重要な知見をもたらした。
Segnextは純粋なCNNのエンコーダと、マルチスケールコンテキストを扱えるデコーダを用いて、効率的なセグメンテーションタスクを行う事ができる。
[1]より、SegNeXtと従来手法の比較
従来のセグメンテーションの知見
有名なモデルではCNNのDeepLabやTransformer系のSegFormerなどがあるが、これらの成功には以下の特性があると考えられる。
- Transformer系のモデルはエンコーダが強力で、強いBackbornネットワークの寄与が大きいこと
- 画像内の様々な大きさのオブジェクトに対して密な予測タスクを解くためには複数解像度の特徴をうまく使うこと
- 空間(縦横)方向のAttentionでセグメンテーションを行う領域内の重要度を扱うために重み付けを行うこと
- 高解像度画像を扱うため計算量は少ないほうが良いこと
これらを考慮して、従来のCNNエンコーダデコーダモデルを再設計したSegnextを作る。
Segnextはエンコーダの各ブロックでVANの方法に則って単純な乗算でAttentionを行い、デコーダではマルチスケール特徴量を使ってローカル/グローバル両方のコンテキストを得られるHamburgerの方法を使うことで低レベルから高レベルまでの情報を集める。
これによりエンコーダは純粋な畳込みベースのモデルとなり、Transformer系のものと比べ計算効率が高くなったうえ、オブジェクトのディテール処理についてはTransformerのそれよりも良い性能を示すようになった。
一般的なセグメンテーションモデルのエンコーダは分類モデル(例えばResNetやConvNeXtなど)であるが、セグメンテーションという密な予測タスクのためには一般画像分類タスクでの改善は少ないかもしれない。従ってSegnextではエンコーダとデコーダを協調するように作りなおした。
SegNeXtの内容
Segnextの貢献は次の部分にある。
- Convk×kをConvk×1とConv1×kの組に分解する[4]の成果に、マルチスケールの受容野とAtentionを加えたこと
- GoogLeNetのようなマルチスケールブランチ構造から特徴を得るが、一般的なものと異なりエンコーダのみをマルチスケール特徴抽出に用いたこと
- 空間方向のAttentionだけでなくChannel方向のAttentionも行うため、VANのラージカーネルAttentionを用いて、さらにこれをマルチスケール特徴抽出につなげたこと
[2]より、VANのLarge Kernel Attention (a)
実際の全体像はmmsegmentaionのconfigとして以下のように書かれている。
norm_cfg = dict(type='SyncBN', requires_grad=True)
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
find_unused_parameters = True
model = dict( # Segnext Base Model ADE160k (segnext.base.512x512.ade.160k.py)
type='EncoderDecoder',
backbone=dict(
embed_dims=[64, 128, 320, 512],
depths=[3, 3, 12, 3],
init_cfg=dict(type='Pretrained', checkpoint='pretrained/mscan_b.pth'),
drop_path_rate=0.1),
decode_head=dict(
type='LightHamHead',
in_channels=[128, 320, 512],
in_index=[1, 2, 3],
channels=512,
ham_channels=512,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
data = dict(samples_per_gpu=4)
evaluation = dict(interval=8000, metric='mIoU')
checkpoint_config = dict(by_epoch=False, interval=8000)
# optimizer
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=10.)
}
)
)
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False
)
Encoder
SegnextのエンコーダはMSCAN(Multi-Scale Convolutional Attention Network)と呼び、ローカル情報を扱うdepthwise Conv、マルチスケール特徴を扱うmulti-branch depthwise strip Conv、Channel間のコンテキストを扱うConv1×1の3つを用いる。
[1]より、SegNeXtの構成とAttention
mmsegmentation内の実装を簡潔にしたものを見ていく。
まず基本パーツについては次のものがある。
-
StemConv
: 畳み込みで[B,3,HW]の画像テンソルを[B,C0,HW/4]にダウンスケールする -
OverlapPatchEmbed
: 通常のConv Patchfyはkernel_size == strideの畳み込みで行うが、この条件を緩めてPatchfy(つまりただのstride付き畳み込み) -
Mlp
: depthwise Convを加えたMLP、Conv1×1 → DWConv3×3 → GELU → Conv1×1 を行う(Conv1×1はLinearと等価であることに注目) -
AttentionModule
: カーネルサイズが7,11,21のdepthwise Convを縦横に分解したものを加算し、マルチスケール特徴に対してAtentionを行う(下図) -
SpatialAttention
: Conv1×1でAttentionModule
を挟んで呼び出す
[2]より、VANに従ったMlpやAttentionの構成
import torch
from torch import nn
from torch.nn import functional as F
class StemConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
):
super(StemConv, self).__init__()
self.proj = nn.Sequential(
nn.Conv2d(in_channels, out_channels//2, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(out_channels//2),
nn.GELU(),
nn.Conv2d(out_channels//2, out_channels, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(out_channels),
)
return None
def forward(self, x):
x = self.proj(x)
_, _, h, w = x.size()
x = x.flatten(2).transpose(1, 2)
return x, h, w
class OverlapPatchEmbed(nn.Module):
def __init__(self, patch_size=3, stride=2, in_chans=3, embed_dim=768):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=patch_size//2)
self.norm = nn.BatchNorm2d(embed_dim)
return None
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = self.norm(x)
x = x.flatten(2).transpose(1, 2)
return x, H, W
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
drop=0.
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.dwconv = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True, groups=hidden_features) # depthwise conv
self.act = nn.GELU()
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.drop = nn.Dropout(drop)
return None
def forward(self, x):
x = self.fc1(x)
x = self.dwconv(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class AttentionModule(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
self.conv3 = nn.Conv2d(dim, dim, 1)
return None
def forward(self, x):
u = x.clone()
attn = self.conv0(x)
attn_0 = self.conv0_1(attn)
attn_0 = self.conv0_2(attn_0)
attn_1 = self.conv1_1(attn)
attn_1 = self.conv1_2(attn_1)
attn_2 = self.conv2_1(attn)
attn_2 = self.conv2_2(attn_2)
attn = attn + attn_0 + attn_1 + attn_2
attn = self.conv3(attn)
return attn * u
class SpatialAttention(nn.Module):
def __init__(self, d_model):
super().__init__()
self.d_model = d_model
self.proj_1 = nn.Conv2d(d_model, d_model, 1)
self.activation = nn.GELU()
self.spatial_gating_unit = AttentionModule(d_model)
self.proj_2 = nn.Conv2d(d_model, d_model, 1)
return None
def forward(self, x):
shorcut = x.clone()
x = self.proj_1(x)
x = self.activation(x)
x = self.spatial_gating_unit(x)
x = self.proj_2(x)
x = x + shorcut
return x
これらにより構成されるBlockと全体の構成は以下である。
class Block(nn.Module):
def __init__(self,
dim,
mlp_ratio=4,
):
super().__init__()
# in training, dropout_ratio = 0.1
drop=0.
drop_path=0.
self.norm1 = nn.BatchNorm2d(dim)
self.attn = SpatialAttention(dim)
self.drop_path = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = nn.BatchNorm2d(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
layer_scale_init_value = 1e-2
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
return None
def forward(self, x, h, w):
b, n, c = x.shape
x = x.permute(0, 2, 1).view(b, c, h, w)
x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
x = x.view(b, c, n).permute(0, 2, 1)
return x
class MSCAN(nn.Module):
def __init__(self,
embed_dims=[64, 128, 256, 512],
depths=[3, 3, 9, 3],
):
super().__init__()
# SegNeXt Encoder
# SegNeXt-T SegNeXt-S SegNeXt-B SegNeXt-L
# HW/4 C=32, Block=3 C=64, Block=2 C=64, Block=3 C=64, Block=3
# HW/8 C=64, Block=3 C=128, Block=2 C=128, Block=3 C=128, Block=5
# HW/16 C=160, Block=5 C=320, Block=4 C=320, Block=12 C=320, Block=27
# HW/32 C=256, Block=2 C=512, Block=2 C=512, Block=3 C=512, Block=3
self.depths = depths
self.num_stages = 4
self.patch_embed0 = StemConv(3, embed_dims[0])
self.block0 = nn.ModuleList([Block(dim=embed_dims[0]) for _ in range(depths[0])])
self.norm0 = nn.LayerNorm(embed_dims[0])
self.patch_embed1 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
self.block1 = nn.ModuleList([Block(dim=embed_dims[1]) for _ in range(depths[1])])
self.norm1 = nn.LayerNorm(embed_dims[1])
self.patch_embed2 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
self.block2 = nn.ModuleList([Block(dim=embed_dims[2]) for _ in range(depths[2])])
self.norm2 = nn.LayerNorm(embed_dims[2])
self.patch_embed3 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
self.block3 = nn.ModuleList([Block(dim=embed_dims[3]) for _ in range(depths[3])])
self.norm3 = nn.LayerNorm(embed_dims[3])
return None
def forward(self, x):
b = x.shape[0]
x, h, w = self.patch_embed0(x)
for block in self.block0:
x = block(x, h, w)
x = self.norm0(x)
x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
p0 = x
x, h, w = self.patch_embed1(x)
for block in self.block1:
x = block(x, h, w)
x = self.norm1(x)
x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
p1 = x
x, h, w = self.patch_embed2(x)
for block in self.block2:
x = block(x, h, w)
x = self.norm2(x)
x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
p2 = x
x, h, w = self.patch_embed3(x)
for block in self.block3:
x = block(x, h, w)
x = self.norm3(x)
x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
p3 = x
return p0, p1, p2, p3
以上によりエンコーダが作られる。
このBlock数やCahnnel数をスケールすることでtiny, small, base, largeモデルが作られる。
[1]より、各スケールの構成
Decoder
セグメンテーションモデルのデコーダは、SegFormerのようなMLP、CNN系のASPPやPSPなどがあるが、Segnextでは複数解像度の特徴マップを結合してHeadの処理を行う。(と書いてあるが、近年のCNN系のものも複数解像度を結合しておこなうFPN的なものを持っているので、マクロデザインに関してはどこが新しいかはわからない)
SegnextのデコーダはLightHamburgerを用いる。LightHamburgerは従来のCNN系のデコーダのHeadより軽量だが、Segnextのエンコーダが強力なため性能と計算効率のトレードオフが良い。
SegFormerのデコーダはstage1~Stage4までの解像度をすべて用いるが、SegnextではStage2~Stage4の3つの解像度を使う。Stage1の特徴マップは低レベルの情報が多すぎて性能が低下し、また計算量もネックになるため使わない。
[1]より、デコーダの比較
LightHamburgerはAttentionを行列分解(MD)によって行うもので、Conv1×1のLinear変換2つで非負行列因子分解(NMF)を挟む構成になっている。低ランク回復問題としてモデル化したMDを解くことで、入力特徴マップを分解し低ランク埋め込みを再構成するHamburgerは、現在のSelf AttentionによるQとKの積のように中間変数として大きな行列を生成する必要がないため時間/空間計算量が軽く、セマンティックセグメンテーションや画像生成など、グローバルコンテキストを学習することが重要なタスクで良い精度/計算量のトレードオフが示されている。
[3]より、Hamburgerモジュールの構成
[3]より、NMFのアルゴリズム
mmsegmentationでの実装を見ていく。mmcvに依存しないように簡潔に改変している。またオリジナルのコードはHamburger本家によるものである。
-
NMF2D
: 非負行列因子分解によりテンソルの行列分解でAttentionを計算する。公式実装では_MatrixDecomposition2DBase
を継承する形で書かれている -
Hamburger
: Conv1×1とNMF2D
を呼びHumburgerモジュールを構築する -
LightHamHead
: 3つの特徴マップを入力しセグメンテーションマスクを出力する。引数のChannel数はmmsegmentationのconfigに準ずる
class NMF2D(nn.Module):
def __init__(self):
super().__init__()
self.spatial = True
self.S = 1
self.D = 512
self.R = 64
self.train_steps = 6
self.eval_steps = 7
self.inv_t = 1 # default 100
self.eta = 0.9
self.rand_init = True
# Use for device verification such as `self.tensordevice.device == torch.device('cpu')`
self.tensordevice = nn.Parameter(torch.empty(0))
return None
def _build_bases(self, B, S, D, R):
# _MatrixDecomposition2DBase default empty
if self.tensordevice.device == torch.device('cuda'):
bases = torch.rand((B * S, D, R)).cuda()
else:
bases = torch.rand((B * S, D, R))
bases = F.normalize(bases, dim=1)
return bases
# @torch.no_grad()
def local_step(self, x, bases, coef):
# _MatrixDecomposition2DBase default empty
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
numerator = torch.bmm(x.transpose(1, 2), bases)
# (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
# Multiplicative Update
coef = coef * numerator / (denominator + 1e-6)
# (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
numerator = torch.bmm(x, coef)
# (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
# Multiplicative Update
bases = bases * numerator / (denominator + 1e-6)
return bases, coef
# @torch.no_grad()
def local_inference(self, x, bases):
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
coef = torch.bmm(x.transpose(1, 2), bases)
coef = F.softmax(self.inv_t * coef, dim=-1)
steps = self.train_steps if self.training else self.eval_steps
for _ in range(steps):
bases, coef = self.local_step(x, bases, coef)
return bases, coef
def compute_coef(self, x, bases, coef):
# _MatrixDecomposition2DBase default empty
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
numerator = torch.bmm(x.transpose(1, 2), bases)
# (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
# multiplication update
coef = coef * numerator / (denominator + 1e-6)
return coef
def forward(self, x, return_bases=False):
B, C, H, W = x.shape
# (B, C, H, W) -> (B * S, D, N)
if self.spatial:
D = C // self.S
N = H * W
x = x.view(B * self.S, D, N)
else:
D = H * W
N = C // self.S
x = x.view(B * self.S, N, D).transpose(1, 2)
if not self.rand_init and not hasattr(self, 'bases'):
bases = self._build_bases(1, self.S, D, self.R)
self.register_buffer('bases', bases)
# (S, D, R) -> (B * S, D, R)
if self.rand_init:
bases = self._build_bases(B, self.S, D, self.R)
else:
bases = self.bases.repeat(B, 1, 1)
bases, coef = self.local_inference(x, bases)
# (B * S, N, R)
coef = self.compute_coef(x, bases, coef)
# (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
x = torch.bmm(bases, coef.transpose(1, 2))
# (B * S, D, N) -> (B, C, H, W)
if self.spatial:
x = x.view(B, C, H, W)
else:
x = x.transpose(1, 2).view(B, C, H, W)
# (B * H, D, R) -> (B, H, N, D)
bases = bases.view(B, self.S, D, self.R)
return x
class Hamburger(nn.Module):
def __init__(self, ham_channels):
super().__init__()
self.hamburger = nn.Sequential(
nn.Conv2d(ham_channels, ham_channels, 1),
nn.ReLU(inplace=True),
NMF2D(),
nn.Conv2d(ham_channels, ham_channels, 1)
)
self.relu = nn.ReLU(inplace=True)
return None
def forward(self, x):
ham = self.hamburger(x)
ham = self.relu(x + ham)
return ham
class LightHamHead(nn.Module):
def __init__(self, in_channels=[128, 320, 512], ham_channels=512, channels=512, num_classes=150):
super().__init__()
input_chs = sum(in_channels)
output_chs = channels
self.upsample_x2 = nn.Upsample(scale_factor=2, mode='bicubic')
self.upsample_x4 = nn.Upsample(scale_factor=4, mode='bicubic')
self.squeeze = nn.Conv2d(input_chs, ham_channels, 1)
self.hamburger = Hamburger(ham_channels)
self.align = nn.Conv2d(ham_channels, output_chs, 1)
self.cls_seg = nn.Conv2d(output_chs, num_classes, 1)
return None
def forward(self, p1, p2, p3):
inputs = torch.cat([p1, self.upsample_x2(p2), self.upsample_x4(p3)], dim=1)
x = self.squeeze(inputs)
x = self.hamburger(x)
output = self.align(x)
output = self.cls_seg(output)
return output
以上に示したエンコーダデコーダを繋げばSegnextのモデルが作られる。
実験結果
論文の表の通り。
Mask2Formerなどは詳細な比較も欲しい所。
[1]より、SegFormerとの比較
[1]より、ADE20K, Cityscapes, COCO-Stuffなどによるベンチマーク
感想
Transformer系でなくてもCNNでも高速&高精度なセグメンテーションモデルをつくることができるという主旨で、実際従来手法より高解像度なinputにおいてはかなり高速に見える。また最近耳にするVANやHamburgなどの手法が実際に応用されており、これからの流れを掴むにも良い論文だった。
しかしデコーダの実装を見てみると少し複雑に感じるところもある。これをMLPや純粋なConcat → Convによる処理で(SoTAでなくともいいので)もっとシビアに高速化できれば実用上も十分に使えるものになる気がする。
実装コードはApache-2.0 licenseで公開されているため、mmsegmentationに収録されることが期待される。
参照
[1] Meng-Hao Guo, Cheng-Ze Lu, Qibin Hou, Zhengning Liu, Ming-Ming Cheng, Shi-Min Hu, "SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation", https://arxiv.org/abs/2209.08575 NeurIPS 2022, 18 Sep 2022
[2] Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu, "Visual Attention Network", https://arxiv.org/abs/2202.09741 11 Jul 2022
[3] Zhengyang Geng, Meng-Hao Guo, Hongxu Chen, Xia Li, Ke Wei, Zhouchen Lin, "Is Attention Better Than Matrix Decomposition?", https://arxiv.org/abs/2109.04553 ICLR 2021, 28 Dec 2021
[4] Chao Peng, Xiangyu Zhang, Gang Yu, Guiming Luo, Jian Sun, "Large Kernel Matters -- Improve Semantic Segmentation by Global Convolutional Network", https://arxiv.org/abs/1703.02719 IEEE Conf. Comput. Vis. Pattern Recog. pp. 4353–4361, 8 Mar 2017
Discussion