🏝️

ELAN PAFPN by ConvNeXt Architecture 実装と解説

2022/12/05に公開

Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie, "A ConvNet for the 2020s", https://arxiv.org/abs/2201.03545 CVPR 2022, 2 Mar 2022

Chien-Yao Wang, Alexey Bochkovskiy, Hong-Yuan Mark Liao, "YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors", https://arxiv.org/abs/2207.02696 6 Jul 2022

YOLOv7の勉強がてら、そのBlock単位でConvNeXtを組み込めないか試しました。
結果的に失敗したのですが、最近の定番手法を眺めることはできてよかったです。

追記: RTMDetでCSPNeXtという形で似た手法が使われていました。なんか先見の明があった気分で得した気持ち...

概要

CNNで定番として使われるConv1 → BN → ReLU → Conv3 → BN → ReLU → Conv1 → BNの伝統的なBottleneck構造を、ConvnextではTransformerの流儀でConv7 → LN → Conv1 → GELU → Conv1の構成に書き換えた。
このアーキテクチャをYOLOv7のPAFPN(Path Aggregation Feature Pyramid Networks)構成要素に適用したPAFPNextを作った。

モチベーション

Resnetに続くCNN系ベースラインとしてはEfficientNet, NFNetなどがあるが、その新定番として論文のベンチマーク欄に置かれるようになったConvnextは、そのアーキテクチャデザイン、特に7→1→1型Bottleneckやconv7×7、Micro DesignのViT化によって、これまでのBottleneckや連続的なconv3×3→BN→ReLUを用いたモデルを再考させた。

[2]より、ResNetなどの畳み込み層のBottleneck

Convnextは同じパラメータ数帯、同じ入力解像度帯においてCNNの中ではとても良い精度を収めており、高精度なモデルの中ではthroughputも非常に大きく、高速である。
Convnextをバックボーンに用いたMask-RCNNによる物体検出やUperNetによるセグメンテーションなどは強力な成果を上げているが、これらのモデルのPAFPNやHeadは従来のアーキテクチャで構成されているため、まだ改善の余地があるように思える。
本稿ではConvnextの内容を確認し、純粋なCNNを用いた物体検出とセグメンテーションの定番アーキテクチャであるPAFPNのEncoder-Decoder部分を、Convnextの流儀で書き換えたらどうなるかを見ていく。

[1]より、ConvNeXtのImageNet Top1 AccとFLOPS比較

ConvNeXtの構成

Convnextの構成について確認する。
Convnextは、SwinTransformerなど2021年代のViT系モデルの定番アーキテクチャデザインによってResnetを書き換えたモデルで、Resnet 50, Resnet 200からそれぞれConvnext tiny, Convnext Baseが作られた。
学習テクニックも最近の定番を用いて、AdamWとlr_scheduler(warmup cosine)による最適化で、weight decay, optimizer momentum, layer-wise lr decay, weight EMA, randaugment, mixup, cutmix, random erasing, label smoothing, stochastic depthによる正則化などを行い、この結果Acc 76.1% → 78.8%に向上、これをスタートラインとしてモデルの構成を行なう。

[1]より、ResNetからの更新とSwinTransformerの比較

Macro Design

画像情報(特に一般物体認識などの画像)は冗長なので、最初に受ける入力画像をそのまま扱うよりダウンサンプルして扱う方が無駄がない。この部分を"Stem"と呼ぶ。これまでのResNetではstride=2のConv7×7とmax poolingで1/4にダウンサンプリングしていたが、ViT系のモデルではStemは位置エンコーディングや画像の分割(patchify)で16×16などの大きさに切り分けている。Convnextはstride=4のConv4×4でpatchに切り分ける4×4 non-overlapping convolution patchify stemを採用する。

また、feat mapの解像度が同一なConv Blockのシリーズを"Stage"と呼ぶ。このBlock数はResnetでは[3,4,6,3]と適当に決められていたが、Transformerの小さなモデルでは[1,1,3,1]*b、大きなモデルでは[1,1,9,1]*bとなるように構築される。Convnext tinyでは[3,3,9,3]のBlock数でステージを作る。

[1]より、各モデルの構成要素

ResNeXt

Resnextは計算量と精度のバランスが良いため、channel数=グループ数としたConv3×3、つまりdepthwise Convを用いる。depthwise Convはchannel単位のAttentionの役割を果たし、これにConv1×1を組み合わせることでchannel方向の特徴と空間方向の特徴をよしなに分離してくれる。そのままでは計算量と精度が落ちるため、channel数をSwin tinyと同じに増やすことで精度が大きく向上する。おそらくだが、SE-Blockとdepthwise Convは等価な操作のため、depthwise Convを使用する近年のアーキテクチャでは採用されていないのだと思われる。

[3]より、ResNeXtのdepthwise Convイメージ

Inverted Bottleneck

TransformerのBlock内ではMLPを通るとき入力channelは4倍になる。Resnetでは"Bottlenek"層と呼ばれるConv1×1 → depthwise Conv3×3 → Conv1×1の計算グラフにおいて、最初のConv1×1でchannel数を少なくして大きなカーネルの計算量を減らし、最後のConv1×1でchannle数をもとに戻す方法を行っていたが、MobileNetv2に採用され普及した"逆Bottlenek"では、最初のConv1×1でchannel数を増やし、Conv3×3で特徴抽出した後、Conv1×1でchannel数を元に戻している。Convnextでは逆Bottlneckを使うことで特徴抽出能力を増強した。特にConvnext baseはこれにより精度が大きく向上した。

Kernel size

ViTはAttentionによる特徴抽出の需要野が大きい。VGG以降Conv3×3を重ねることで需要野を補う方法が主流になり、最近のハードウェアもそれ用に効率化されてるが、SwinではLocal Attenstionにより少なくとも7×7以上のカーネルでAttentionを計算している。TransformerではAttention計算がMLPの前にあるので、それを再現するため大きなカーネルの畳み込みを逆Bottleneckの前、つまりConv3×3 → Conv1×1 → Conv1×1のように持ってくる。さらにConv3×3のカーネル数は据え置きにすることで、channle数を増やすためのカーネル計算はConv1×1の役割になる。これにより計算量が削減されて副作用で精度も悪くなるが、Conv3×3をConv7×7にすることで精度が向上する。7以上の大きさのカーネルでは逆に精度が悪くなった。

Micro Design

Transformerで用いられている構成をResnetに輸入する。ResnetではBlock内の構成が(1)だったところ、Transformer(pre norm)では(2)の構成なため、ConvnextではReLUをGELUに変更し、構成を(3)に変更した。

(1) Conv → BN → ReLU → Conv → BN → ReLU
(2)LN → Attention → LN → MLP(Conv1×1 → GELU → Conv1×1)
(3)Conv3×3 → LN → Conv1×1 → GELU → Conv1×1

最後に、Resnetでは縦横方向のダウンサンプリングをstride=2のConv3×3で行っていたが、ConvnextではStageとStageの間にstride=2のConv2×2を挟むことで行った。このとき畳み込みの前にLNを置かなければ学習が発散するので注意。

[1]より、ConvNeXt 1BlockとResNet、SwinTransformer

これによりConvnext tinyの精度はSwin tinyの精度公称値81.3%を超え、82.0%となる。計算量(FLOPS)は殆どSwinと同じだが、純粋なCNNで特殊なモジュールが必要ないため、GPUでのthroughputはSwinより高速である。
疑問点として、BNとLNではBNの方が良いとされているが、ConvnextではLNの方が精度に+0.1の寄与があったとのことで、そのような相性問題があるのかは怪しく、BNか、Batch sizeに依存したくないならGroupNormalizeを使う方が良いように思える。またGELUもSwishやMishが良いイメージがあるものの、実際に軽くCifer100などを回してみるとGELUが安定して高精度だったことも不思議だった。

ConvNeXtの実装

timmのものを参考にして直書きした。その際簡潔であることに重きを置いたため一般性をかなり失っている点に注意。

convnext.py
import torch
from torch import nn
from torch.nn import functional as F


# LayerNorm for channels of '2D' spatial NCHW tensors
class LayerNorm2d(nn.LayerNorm):
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        return None


    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x



class ConvnextBlock(nn.Module):
    def __init__(self, in_chs, out_chs=None, kernel_size=7, stride=1):
        super().__init__()
        
        out_chs = out_chs or in_chs
        pad_dw = (kernel_size - stride) // 2  # in HW == out HW
        mlp_ratio = 4
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_chs, 
                out_chs, 
                kernel_size=kernel_size, 
                padding=pad_dw, 
                stride=stride, 
                groups=in_chs
            ),  # depthwise conv
            LayerNorm2d(out_chs),
            nn.Conv2d(out_chs, int(mlp_ratio * out_chs), kernel_size=1, bias=True),
            nn.GELU(),
            nn.Conv2d(int(mlp_ratio * out_chs), out_chs, kernel_size=1, bias=True),
        )
        return None


    def forward(self, x):
        shortcut = x
        x = self.conv(x)
        x = x + shortcut
        return x



class ConvnextStage(nn.Module):
    def __init__(self, in_chs, out_chs, kernel_size=7, downsample=True, depth=3):
        super().__init__()

        if downsample:
            self.downsample = nn.Sequential(
                LayerNorm2d(in_chs),
                nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2),
            )
            in_chs = out_chs
        else:
            self.downsample = nn.Identity()

        stage_blocks = []
        for _ in range(depth):
            stage_blocks += [ConvnextBlock(in_chs=in_chs, out_chs=out_chs, kernel_size=kernel_size)]
            in_chs = out_chs
        self.blocks = nn.Sequential(*stage_blocks)
        return None


    def forward(self, x):
        x = self.downsample(x)
        x = self.blocks(x)
        return x

ConvnextStageではまずconv2×2でダウンサンプルを行い、その後ConvnextBlockを回数分呼び出す。
ConvnextBlockでは残差接続とConv3×3, LayerNorm, Conv1×1, GELU, Conv1×1を行う。Conv1×1はConvMLPの構成要素で、これらは特徴マップの整形とnn.Linearを用いたMLPを通すことと同義である。

これを用いてConvnext tinyを作ると次のようになる。
Convnet tinyはStage毎にBlockを[3,3,9,3]個持ち、ステージ間でchannelの深さが[3,96,192,384,768]となる。

convnext.py
model = nn.Sequential(
    nn.Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4)), # Stem
    LayerNorm2d(96),                                     # Stem
    ConvnextStage(96,   96, downsample=False),
    ConvnextStage(96,  192, downsample=True),
    ConvnextStage(192, 384, downsample=True),
    ConvnextStage(384, 768, downsample=True),
    nn.AdaptiveAvgPool2d(1),                             # Head
    nn.Flatten(start_dim=1),                             # Head
    nn.Linear(in_features=768, out_features=NUM_CLS)     # Head
)

PAFPNeXt

timmのベンチマークによれば、Convnext tinyはCSPResnext 50やCSPDarknet 53より高速で、ImageNetの精度も1.5%以上高い。
本章ではDeeplabv3+とYOLOv7(ELAN)を参考にConvnextデザインを適用した汎用的なPAFPNを作る。

[6]より、PAFPNの全体像

DeepLabv3+

セグメンテーションで広くベースラインとして用いられる。
本体の設計はエンコーダとデコーダからなり、エンコーダではXceptionモデルを深く現代的なMicroデザインに修正したAligned Xceptionをバックボーンとして1/4解像度の特徴マップと1/16解像度の特徴マップを出力し、1/16解像度の特徴マップに対して何種類かのdilation(カーネルの膨張率、論文ではAtrous Separable Convolution)で受容野を広くした畳み込みを行いそれらを結合した特徴マップを得るASPP(Atrous Spatial Pyramid Pooling)を通す。

[4]より、ASPPを用いたエンコーダデコーダモデル
[4]より、dilationを設定した畳み込み

デコーダでは低解像度の特徴マップを拡大して高解像度の特徴マップと結合し、いくつか畳み込み層を経てから入力の解像度まで拡大して出力する。dilationを用いてパラメータ数を抑えながら広い受容野を得たことで、解像度別の特徴マップ数を減らして計算量を削減し、それでいてセグメンテーションマスク細部の精度を保つことが可能である。
現在でもUNetとともに医用画像のセグメンテーションなどで定番として使われている。

YOLOv7

1ステージ型物体検出モデルの定番で、2022年のリアルタイム物体検出では最も精度と推論速度の効率が良い。
YOLOv7ではバックボーンにCSPDarknetの後継としてELANを用いる。これは結合によりチャンネル数を増やしていくことで畳み込みの計算量を大きく抑えたモデルで、1/8解像度、1/16解像度、1/32解像度の特徴マップを出力し、PAFPNでそれぞれの特徴マップを畳み込んで解像度をあわせて結合する。

[5]より、ELANとその改善版Extended ELAN

Headでは3つの特徴マップのchannel数をあわせて回帰タスクを行う。Headの畳み込みはRepVGGで用いられるデザイン(reparameterization)を使っていて、学習時はスキップコネクションを付けたグラフを用いるが、推論時はそれを排除したシーケンシャルな畳み込み層になる。これにより学習の発散を抑えながら推論時の速度を底上げを図っている。

[5]より、RepConvの探索

回帰タスクではアンカーボックスと正解BBoxを輸送元輸送先に見立てて最適輸送問題として解く。また、学習時PAFPNの特徴マップを引っ張ってきてAuxiliary Headで回帰タスクを解かせることで、損失の伝搬が深い層の学習を効率化し、性能の向上に寄与している。

特徴抽出部の畳み込みには関係ないため本稿では回帰部分は作らない。Convnext baseのchannle数を用いれば回帰headの入出力channle数が等しくなるので単純な付替えで動くと思う。

PAFPNeXtの実装

方針を考える。まずELAN内の2つのConv3×3により構成される畳み込み層をConv7×7→Conv1×1→Conv1×1に変える。またYOLOオリジナルの実装ではSPPCSPCというSPPベースのものが使われているが、このBottleneck部分をConvnext化する。Conv1×1(線形変換)が1つ挟まれているだけの部分はそのままにして、BNをLNに、SwishをGELUに置き換える。

channel数を圧縮する畳み込みではdepthwise Convが複雑になるため、単純な畳み込みができるConvnext Blockとして次を定義しておく。またConvnextOneではdilationを用いた畳み込みができる。

pafpn.py
class ConvnextOne(nn.Module):
    def __init__(self, in_chs, out_chs, kernel_size=7, stride=1, dilation=1):
        super().__init__()

        middle_chs = min(in_chs, out_chs)  # middle channel should be small
        pad = math.ceil(((kernel_size-1)*dilation+1 - stride) / 2)
        mlp_ratio = 4

        self.stream = nn.Sequential(
            nn.Conv2d(
                in_chs, 
                middle_chs, 
                kernel_size=kernel_size, 
                stride=stride, 
                padding=pad, 
                dilation=dilation,
            ),
            LayerNorm2d(middle_chs),
            nn.Conv2d(middle_chs, int(mlp_ratio * out_chs), kernel_size=1, bias=True),
            nn.GELU(),
            nn.Conv2d(int(mlp_ratio * out_chs), out_chs, kernel_size=1, bias=True),
        )
        return None


    def forward(self, x):
        x = self.stream(x)
        return x

ELAN BlockをConvnextデザインで書き換えたものを定義する。

pafpn.py
class ElanBlock(nn.Module):
    def __init__(self, in_chs, out_chs):
        super().__init__()

        middle_chs = in_chs//2
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_chs, middle_chs, kernel_size=1),
            LayerNorm2d(middle_chs),
            nn.GELU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_chs, middle_chs, kernel_size=1),
            LayerNorm2d(middle_chs),
            nn.GELU(),
        )
        self.conv3 = ConvnextOne(middle_chs, middle_chs)
        self.conv4 = ConvnextOne(middle_chs, middle_chs)
        self.conv5 = nn.Sequential(
            nn.Conv2d(middle_chs*4, out_chs, kernel_size=1),
            LayerNorm2d(out_chs),
            nn.GELU(),
        )
        return None


    def forward(self, x):
        stream1 = self.conv1(x)
        stream2 = self.conv2(x)
        stream3 = self.conv3(stream2)
        stream4 = self.conv4(stream3)
        output = torch.cat([stream1, stream2, stream3, stream4], dim=1)
        output = self.conv5(output)
        return output

YOLOv7のSPPCSPCを作る。Deeplabv3+は1/16解像度でASPPを通るが、YOLOv7は画像の入力解像度が大きいので1/32解像度で行う。オリジナルではMaxPool2d部分のkernelは[1, 7, 9, 13]だが、ASPPのデザインに従って畳み込みのdilationで代用する。

pafpn.py
class SppcspcBlock(nn.Module):
    def __init__(self, in_chs, out_chs):
        super().__init__()

        self.conv1 = ConvnextOne(in_chs, in_chs//2)
        self.maxpool1 = nn.Identity()
        self.maxpool7 = ConvnextOne(in_chs//2, in_chs//2, kernel_size=3, dilation=3)
        self.maxpool9 = ConvnextOne(in_chs//2, in_chs//2, kernel_size=3, dilation=4)
        self.maxpool13 = ConvnextOne(in_chs//2, in_chs//2, kernel_size=3, dilation=6)
        self.conv2 = ConvnextOne(in_chs*2, in_chs//2)
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_chs, in_chs//2, kernel_size=1),
            LayerNorm2d(in_chs//2),
            nn.GELU(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_chs, out_chs, kernel_size=1),
            LayerNorm2d(out_chs),
            nn.GELU(),
        )
        
        return None


    def forward(self, x):
        output = self.conv1(x)
        output = torch.cat([
            self.maxpool1(output),
            self.maxpool7(output),
            self.maxpool9(output),
            self.maxpool13(output),
        ], dim=1)
        output = self.conv2(output)
        x = self.conv3(x)
        output = torch.cat([output, x], dim=1)
        output = self.conv4(output)

        return output

RepConvは学習時は残差接続の加算を伴うが、推論時はメインの畳み込みのみになる。
ただ、YOLOv7の実装では解像度ごとに合わせて3blockしかないので、この変更がどれだけ効いているのかは不明。実際通用GPUモデル以外では外されていて、ベンチマーク指標を魅せるインパクトのための最終調整以上の良さは無いかもしれない。

pafpn.py
class RepConv(nn.Module):
    def __init__(self, in_chs):
        super().__init__()

        self.stream = nn.Sequential(
            nn.Conv2d(in_chs, in_chs, kernel_size=3, padding=1),
            LayerNorm2d(in_chs),
        )
        self.rep1 = nn.Sequential(
            nn.Conv2d(in_chs, in_chs, kernel_size=3, padding=1),
            LayerNorm2d(in_chs),
        )
        self.rep2 = LayerNorm2d(in_chs)
        self.conv = nn.Sequential(
            nn.Conv2d(in_chs, in_chs, kernel_size=1),
            LayerNorm2d(in_chs),
            nn.Sigmoid()
        )
        return None


    def forward(self, x):
        if self.training:
            x = self.stream(x) + self.rep1(x)  + self.rep2(x)
        else: 
            x = self.stream(x)
        x = self.conv(x)
        return x

これらを用いてBackbornとPAFPNを組むと以下のようになる。

pafpn.py
class BackbornElanext(nn.Module):
    def __init__(self, depth=(3, 128, 256, 512, 1024)):
        super().__init__()

        # [B, C0, HW] -> [B, C1, HW/2]
        self.stem = nn.Sequential(
            nn.Conv2d(depth[0], depth[1], kernel_size=4, stride=2, padding=1),
            LayerNorm2d(depth[1])
        )
        # [B, C1, HW/2] -> [B, C2, HW/4]
        self.convnext = ConvnextStage(depth[1], depth[2], depth=3)
        # [B, C2, HW/4] -> [B, C2, HW/4]
        self.elan1 = ElanBlock(depth[2], depth[2])
        # [B, C2, HW/8] -> [B, C2, HW/8]
        self.mp1 = ConvnextStage(depth[2], depth[2], depth=1)
        # [B, C2, HW/8] -> [B, C3, HW/8]  # -> c3
        self.elan2 = ElanBlock(depth[2], depth[3])
        # [B, C3, HW/8] -> [B, C3, HW/16]
        self.mp2 =  ConvnextStage(depth[3], depth[3], depth=1)  
        # [B, C3, HW/16] -> [B, C4, HW/16]  # -> c4
        self.elan3 = ElanBlock(depth[3], depth[4])
        # [B, C4, HW/16] -> [B, C4, HW/32]
        self.mp3 =  ConvnextStage(depth[4], depth[4], depth=1)
        # [B, C4, HW/32] -> [B, C4, HW/32]  # -> c5
        self.elan4 = ElanBlock(depth[4], depth[4])

        # [B, C3, HW/8] -> [B, C2, HW/8]  # -> c3
        self.conv3 = nn.Sequential(
            nn.Conv2d(depth[3], depth[2], kernel_size=1),
            LayerNorm2d(depth[2]),
            nn.GELU(),
        )
        # [B, C4, HW/16] -> [B, C3, HW/16]  # -> c4
        self.conv4 = nn.Sequential(
            nn.Conv2d(depth[4], depth[3], kernel_size=1),
            LayerNorm2d(depth[3]),
            nn.GELU(),
        )

        return None


    def forward(self,x):
        x = self.stem(x)
        x = self.convnext(x)
        x = self.elan1(x)
        x = self.mp1(x)
        c3 = self.elan2(x)
        x = self.mp2(c3)
        c4 = self.elan3(x)
        x = self.mp3(c4)

        c3 = self.conv3(c3)
        c4 = self.conv4(c4)
        c5 = self.elan4(x)
        
        return c3, c4, c5



class PafpnElanext(nn.Module):
    def __init__(self, depth=(3, 128, 256, 512, 1024), out_chs=None):
        super().__init__()

        out_chs = out_chs or depth[2]
 
        # FPN Neck
        # [B, C4, HW/32] -> [B, C3, HW/32]  # -> p5
        self.cppcspc = SppcspcBlock(depth[4], depth[4])
        # [B, C3, HW/32] -> [B, C3, HW/16]
        self.upsample5 = nn.Sequential(
            nn.Conv2d(depth[4], depth[3], kernel_size=1),
            LayerNorm2d(depth[3]),
            nn.GELU(),
            nn.Upsample(scale_factor=2),
        )
        # [B, C3+C3, HW/16] -> [B, C2, HW/16]  # -> p4
        self.elan4 = ElanBlock(depth[3]+depth[3], depth[3])
        # [B, C2, HW/16] -> [B, C2, HW/8]
        self.upsample4 = nn.Sequential(
            nn.Conv2d(depth[3], depth[2], kernel_size=1),
            LayerNorm2d(depth[2]),
            nn.GELU(),
            nn.Upsample(scale_factor=2),
        )
        # [B, C2+C2, HW/8] -> [B, C2, HW/8]  # -> p3
        self.elan3 = ElanBlock(depth[2]+depth[2], depth[2])

        # PA Neck
        # [B, C2, HW/8] -> n3 [B, {out_chs}, HW/8]  # -> n3
        self.conv3_channelmatch = nn.Sequential(
            nn.Conv2d(depth[2], out_chs, kernel_size=1),
            LayerNorm2d(depth[2]),
            nn.GELU(),
        )
        # [B, C2, HW/8] -> [B, C3, HW/16]
        self.downsample4 = ConvnextStage(depth[2], depth[3], depth=1)
        # [B, C3+C3, HW/16] -> [B, {out_chs}, HW/16]  # -> n4
        self.elan4_pa = ElanBlock(depth[3]+depth[3], out_chs)
        # [B, C3, HW/16] -> [B, C4, HW/32]
        self.downsample5 = ConvnextStage(out_chs, depth[4], depth=1)
        # [B, C4+{out_chs}, HW/16] -> [B, {out_chs}, HW/16]  # -> n5
        self.elan5_pa = ElanBlock(depth[4]+depth[4], out_chs)


        self.rep3 = RepConv(out_chs)
        self.rep4 = RepConv(out_chs)
        self.rep5 = RepConv(out_chs)

        return None


    def forward(self, c3, c4, c5):
        # c3 [B, C2, HW/8], c4 [B, C3, HW/16], c5 [B, C4, HW/32]
        
        p5 = self.cppcspc(c5)
        p4 = self.upsample5(p5)
        p4 = self.elan4(torch.cat([c4, p4], dim=1))
        p3 = self.upsample4(p4)
        p3 = self.elan3(torch.cat([c3, p3], dim=1))
        
        # p3 [B, C2, HW/8], p4 [B, C2, HW/16], p5 [B, C3, HW/32]

        n3 = self.conv3_channelmatch(p3)
        n4 = self.downsample4(p3)
        n4 = self.elan4_pa(torch.cat([p4, n4], dim=1))
        n5 = self.downsample5(n4)
        n5 = self.elan5_pa(torch.cat([p5, n5], dim=1))

        n3 = self.rep3(n3)
        n4 = self.rep4(n4)
        n5 = self.rep5(n5)
        
        # n3 [B, feat, HW/8], n4 [B, feat, HW/16], n5 [B, feat, HW/32]
        return n3, n4, n5

以上によりConvnextデザインのYOLOv7 ELAN PAFPNを作ることができた。
実際の出力は次のようになる。この特徴マップはYOLOv7と同じshpaeを持っているため、そのままYOLOv7回帰Headへ与えてClass_logit, IoU, BBoxを計算し、Lossの合計をバックワードすれば物体検出の一通りの流れが可能になる。Auxiliary Lossも中間出力を引っ張って来くれば実装に結合できるので、少し弄れば学習可能になると思う。

backborn = BackbornElanext()
pafpn = PafpnElanext()

p3, p4, p5 = backborn(torch.ones([1,3,640,640]))
print(p3.shape, p4.shape, p5.shape)
# torch.Size([1, 256, 80, 80]) torch.Size([1, 512, 40, 40]) torch.Size([1, 1024, 20, 20])

n3, n4, n5 = pafpn(p3, p4, p5)
print(n3.shape, n4.shape, n5.shape)
# torch.Size([1, 256, 80, 80]) torch.Size([1, 256, 40, 40]) torch.Size([1, 256, 20, 20])

ただし、PAFPNは解像度方向のUpsamplingやDwonsamplingが多いため、ConvnextではNormレイヤーが少ないこともあり学習が発散する可能性が高い。実際にYOLOX Headに結合して数iteration回してみたが、loss=nan, iou=nan.0, obj=8.9e+3, cls=2.82e-6という有様だったため改善の必要がある。

おわりに

Convnextのような「ViTで用いられる方法をCNNに適用する」アプローチは、2022年の画像分野の基礎タスクで良い精度/速度の効率を打ち出すアーキテクチャデザインとしてインパクトのある結果を残した。例えばNIPSでAttentionそのものをdepthwise Convで構築することでSegformerを効率化したモデル[7]が採択されており、これは以降のベースラインとなるように思われる。
また今回作ったELANextは学習的な意味が強かったので実験をしなかった(事前学習した重みがあるモデルを使ったほうが実用上お得で自作する必要性は殆どないため...)が、物体認識モデルでもこのようなアイディアのモデルが出てくると、モジュール単位の比較ができるようになって楽しいので誰かがんばってください(他力本願)。

参照

[1] Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie, "A ConvNet for the 2020s", https://arxiv.org/abs/2201.03545 CVPR 2022, 2 Mar 2022

[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun, "Deep Residual Learning for Image Recognition", https://arxiv.org/abs/1512.03385 10 Dec 2015
[3] Saining Xie, Ross Girshick, Piotr Dollár, Zhuowen Tu, Kaiming He, "Aggregated Residual Transformations for Deep Neural Networks", https://arxiv.org/abs/1611.05431 CVPR 2017, 11 Apr 2017

[4] Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam, "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation", https://arxiv.org/abs/1802.02611 ECCV 2018, 22 Aug 2018

[5] Chien-Yao Wang, Alexey Bochkovskiy, Hong-Yuan Mark Liao, "YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors", https://arxiv.org/abs/2207.02696 6 Jul 2022

[6] Shu Liu, Lu Qi, Haifang Qin, Jianping Shi, Jiaya Jia, "Path Aggregation Network for Instance Segmentation", https://arxiv.org/abs/1803.01534v4 CVPR 2018, 18 Sep 2018

[7] Meng-Hao Guo, Cheng-Ze Lu, Qibin Hou, Zhengning Liu, Ming-Ming Cheng, Shi-Min Hu, "SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation", https://arxiv.org/abs/2209.08575 NeurIPS 2022, 18 Sep 2022

Discussion