
【PyTorch】Implementing CenterNet

Published on 2022/08/01

Introduction

Following my previous post on SimSiam, I read the original CenterNet paper and implemented the model myself.

This article gives a brief overview of CenterNet and describes the environment I used.

The full code for this article is available in the GitHub repository below, so please take a look. Stars and pull requests are always appreciated↓

What is CenterNet?

CenterNet is an anchor-free object detection model, proposed in the 2019 paper "Objects as Points" (arXiv:1904.07850). For each object, the algorithm predicts

  1. a heatmap of object center locations
  2. offsets of the center coordinates
  3. the width and height of each object


for a total of three outputs.

CenterNet

In this article, we will build a model with the structure shown above.
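
At inference time, detections are read off these three outputs: local maxima of the heatmap are treated as object centers, and the offset and size maps at those locations yield the boxes. A 3×3 max-pooling is commonly used as a cheap substitute for NMS. Below is a minimal sketch of the peak-extraction step (the function name here is my own illustration, not necessarily what the repository uses):

```python
import torch
import torch.nn.functional as F


def get_local_maximum(heat: torch.Tensor, kernel: int = 3) -> torch.Tensor:
    """Keep only local peaks of the heatmap.

    Max-pooling with stride 1 leaves a value unchanged only where it is
    the maximum of its neighborhood, so comparing against the pooled map
    zeroes out everything except local maxima (pseudo-NMS).
    """
    pad = (kernel - 1) // 2
    hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
    keep = (hmax == heat).float()
    return heat * keep
```

Each surviving peak is then combined with the offset and width/height predicted at the same grid cell to produce a bounding box.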

The paper is available here↓

What is an anchor?

An anchor is a predefined bounding box; at each location, k boxes with different aspect ratios are defined. Running detection for each of these boxes increases the number of objects that can be detected at the same time. Anchors originated in Faster R-CNN and were adopted by the YOLO family starting with YOLOv2.
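
To make this concrete, here is a small sketch (my own illustration, not code from any particular detector) of how k anchor shapes can be enumerated. Each anchor keeps the same area while its aspect ratio h/w varies:

```python
import itertools


def make_anchors(base_size: float, scales, ratios):
    """Enumerate k = len(scales) * len(ratios) anchor shapes (w, h).

    Each anchor has area (base_size * scale)**2 and aspect ratio
    h / w == ratio, the usual parameterization in anchor-based detectors.
    """
    anchors = []
    for s, r in itertools.product(scales, ratios):
        w = base_size * s / r ** 0.5
        h = base_size * s * r ** 0.5
        anchors.append((w, h))
    return anchors
```

With scales=(1, 2) and ratios=(0.5, 1, 2), for example, this yields k = 6 anchor shapes per location. CenterNet does away with this machinery entirely, which is why it is called anchor-free.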

Environment

PC: MacBook Pro (16-inch, 2019)
OS: macOS Monterey
CPU: 2.3 GHz 8-core Intel Core i9
Memory: 16 GB
Python: 3.9

Libraries

The libraries and versions used in this article are listed below, although in most cases simply running

terminal
pip install torch torchvision

should be fine without worrying about exact versions.

The versions I used are as follows.

Library: Version
torch: 1.10.1
torchvision: 0.11.2

Directory structure

centernet
└── centernet
    ├── __init__.py
    ├── backbone
    │   ├── __init__.py
    │   ├── resnet18.py
    │   └── utils
    │       └── hub.py
    ├── centernet.py
    ├── losses
    │   ├── __init__.py
    │   ├── gaussian_focal_loss.py
    │   └── l1_loss.py
    ├── modules
    │   ├── __init__.py
    │   ├── conv_module.py
    │   ├── head.py
    │   └── neck.py
    └── utils
        ├── __init__.py
        ├── gaussian_target.py
        └── nms.py
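
Of these files, utils/gaussian_target.py is responsible for splatting each ground-truth center onto the target heatmap as a 2D Gaussian. A minimal sketch of such a kernel (the signature is illustrative; see the repository for the actual code):

```python
import torch


def gaussian_2d(radius: int, sigma: float) -> torch.Tensor:
    """(2*radius+1) x (2*radius+1) Gaussian with peak 1 at the center.

    Placed around each object center, this gives the soft target that
    the heatmap branch is trained against with a Gaussian focal loss.
    """
    coords = torch.arange(-radius, radius + 1, dtype=torch.float32)
    y = coords.view(-1, 1)
    x = coords.view(1, -1)
    return torch.exp(-(x * x + y * y) / (2 * sigma * sigma))
```

The peak value is exactly 1 at the object center and decays smoothly toward the edges of the patch.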

Implementation

From here on, we write the actual CenterNet code.
Note that this article assumes ResNet18 is used as the backbone.

CTResNetNeck

First, we implement CTResNetNeck.
For the ConvModule implementation, please see the GitHub repository.

centernet/modules/neck.py
import torch
from torch import nn

from .conv_module import ConvModule


class CTResNetNeck(nn.Module):
    """The neck used in `CenterNet <https://arxiv.org/abs/1904.07850>`_ for
    object classification and box regression.
    Args:
        in_channels (int): Number of input channels.
        num_deconv_filters (tuple[int]): Number of filters per deconv stage.
        num_deconv_kernels (tuple[int]): Kernel size of each deconv stage.
    """

    def __init__(
        self,
        in_channels,
        num_deconv_filters,
        num_deconv_kernels,
    ) -> None:
        super().__init__()
        assert len(num_deconv_filters) == len(num_deconv_kernels)
        self.in_channels = in_channels
        self.deconv_layers = self._make_deconv_layer(
            num_deconv_filters, num_deconv_kernels
        )

    def _make_deconv_layer(self, num_deconv_filters, num_deconv_kernels):
        """use deconv layers to upsample backbone's output."""
        layers = []
        for i in range(len(num_deconv_filters)):
            feat_channels = num_deconv_filters[i]
            conv_module = ConvModule(
                in_channels=self.in_channels,
                out_channels=feat_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                conv_fn="Conv2d",
                norm_fn="BatchNorm2d",
            )
            layers.append(conv_module)
            upsample_module = ConvModule(
                in_channels=feat_channels,
                out_channels=feat_channels,
                kernel_size=num_deconv_kernels[i],
                stride=2,
                padding=1,
                conv_fn="ConvTranspose2d",
                norm_fn="BatchNorm2d",
            )
            layers.append(upsample_module)
            self.in_channels = feat_channels

        return nn.Sequential(*layers)

    def forward(self, x) -> torch.Tensor:
        outs = self.deconv_layers(x)
        return outs
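
The neck above depends on ConvModule, whose full implementation lives in the repository. A minimal stand-in consistent with the conv_fn / norm_fn string arguments used here might look like the following (this is my own sketch; the real module may differ):

```python
import torch
from torch import nn


class ConvModule(nn.Module):
    """conv (or deconv) -> BatchNorm -> ReLU, with layers looked up by name."""

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, conv_fn="Conv2d", norm_fn="BatchNorm2d"):
        super().__init__()
        conv_cls = getattr(nn, conv_fn)   # nn.Conv2d or nn.ConvTranspose2d
        norm_cls = getattr(nn, norm_fn)   # nn.BatchNorm2d
        self.conv = conv_cls(in_channels, out_channels, kernel_size,
                             stride=stride, padding=padding, bias=False)
        self.bn = norm_cls(out_channels)
        self.activate = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activate(self.bn(self.conv(x)))
```

With stride=2, padding=1, and kernel_size=4, the ConvTranspose2d variant exactly doubles the spatial resolution, which is why the neck passes num_deconv_kernels=(4, 4, 4).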

CenterNetHead

centernet/modules/head.py
from typing import Dict

import torch
from torch import nn


class CenterNetHead(nn.Module):
    def __init__(
        self, in_channels: int = 64, feat_channels: int = 64, num_classes: int = 4
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.num_classes = num_classes

        self.heatmap_head = CenterNetHead._build_head(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=num_classes,
        )
        self.wh_head = CenterNetHead._build_head(
            in_channels=in_channels, feat_channels=feat_channels, out_channels=2
        )
        self.offset_head = CenterNetHead._build_head(
            in_channels=in_channels, feat_channels=feat_channels, out_channels=2
        )

    @staticmethod
    def _build_head(in_channels, feat_channels, out_channels) -> nn.Module:
        """Build head for each branch."""
        layer = nn.Sequential(
            nn.Conv2d(in_channels, feat_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(feat_channels, out_channels, kernel_size=1),
        )
        return layer

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        heatmap = self.heatmap_head(x)
        wh = self.wh_head(x)
        offset = self.offset_head(x)

        return {"heatmap": heatmap, "wh": wh, "offset": offset}
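
Note that the heatmap branch returns raw logits; following the CenterNet paper, a sigmoid is applied at inference time so that each cell holds a per-class center probability in (0, 1). A tiny illustration (the tensor sizes here are just an example):

```python
import torch

# Hypothetical raw output of the heatmap branch:
# batch 1, num_classes 4, on a 128 x 128 feature map
heatmap_logits = torch.randn(1, 4, 128, 128)

# Sigmoid turns logits into per-class center probabilities in (0, 1)
center_probs = heatmap_logits.sigmoid()
```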

CenterNet

centernet/centernet.py
from typing import Dict, Tuple

import torch
from torch import nn

from .backbone import resnet18
from .modules import CenterNetHead, CTResNetNeck


class CenterNet(nn.Module):
    def __init__(self, num_classes: int = 4) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.backbone = resnet18(pretrained=True)
        self.neck = CTResNetNeck(
            in_channels=512,
            num_deconv_filters=(256, 128, 64),
            num_deconv_kernels=(4, 4, 4),
        )
        self.bbox_head = CenterNetHead(
            in_channels=64, feat_channels=64, num_classes=num_classes
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        x = self.backbone(x)
        x = self.neck(x)
        feature = self.bbox_head(x)

        return feature
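
As a quick sanity check on the architecture: ResNet18 downsamples the input by a factor of 32, and the three deconv stages of the neck each upsample by 2, so the head operates at 1/4 of the input resolution:

```python
# For a 512 x 512 input image:
input_size = 512
backbone_stride = 32          # total downsampling of ResNet18
neck_upsampling = 2 ** 3      # three ConvTranspose2d stages with stride 2
feat_size = input_size // backbone_stride * neck_upsampling
print(feat_size)  # -> 128, i.e. an overall output stride of 4
```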

Checking the model

Now let's verify the model we have built.
If everything went well, printing the model should show the following structure.

CenterNet(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): ResLayer(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
    )
    (layer2): ResLayer(
      (0): BasicBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
    )
    (layer3): ResLayer(
      (0): BasicBlock(
        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
    )
    (layer4): ResLayer(
      (0): BasicBlock(
        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
    )
  )
  (neck): CTResNetNeck(
    (deconv_layers): Sequential(
      (0): ConvModule(
        (conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
      (1): ConvModule(
        (conv): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
      (2): ConvModule(
        (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
      (3): ConvModule(
        (conv): ConvTranspose2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
      (4): ConvModule(
        (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
      (5): ConvModule(
        (conv): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
  )
  (bbox_head): CenterNetHead(
    (heatmap_head): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1))
    )
    (wh_head): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (offset_head): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)

Conclusion

In this article, we implemented CenterNet with a ResNet18 backbone in PyTorch.
As with the previous article, working directly from the original paper made some parts challenging, but I managed to complete the implementation[1].

If you find any mistakes, I would appreciate a gentle heads-up.

Footnotes
  1. Reference: MMDetection ↩︎
