【PyTorch】CenterNet を実装してみた
はじめに
前回の SimSiam に続き,CenterNet も実際の論文を読んで自分で実装してみました.
CenterNet についての軽い説明と筆者が用いている環境について話します.
本記事の詳細については以下の GitHub にコードを載せているのでよければ見てください.Star や Pull Request 等頂けるとやる気が出ます↓
CenterNet とは
CenterNet とは,アンカーレスな物体検出を行う機械学習モデルで,2019 年に論文 "Objects as Points" として arXiv で公開されました.アルゴリズムとしては
- 物体の中心座標のヒートマップ
- 中心座標のオフセット
- 物体のサイズ
の計3つを推論します.
本記事では上記の様なモデルを作成します
論文はこちら↓
アンカーとは
アンカーとは,予め決められたバウンディングボックスで,k個のアスペクト比の異なるボックスで定義されます.各バウンディングボックスごとに物体検出を行うことで,同時に検出できるオブジェクト数を増加させることができ,YOLOv2 から導入されています.
環境
PC | MacBook Pro (16-inch, 2019) |
---|---|
OS | Monterey |
CPU | 2.3 GHz 8コアIntel Core i9 |
メモリ | 16GB |
Python | 3.9 |
ライブラリ
今回の記事で用いるライブラリとバージョンをまとめますが,特に気にせず
pip install torch torchvision
で問題ないかと思います.
以下,使用するライブラリのバージョンです.
ライブラリ | バージョン |
---|---|
torch | 1.10.1 |
torchvision | 0.11.2 |
階層構造
centernet
└── centernet
├── __init__.py
├── backbone
│ ├── __init__.py
│ ├── resnet18.py
│ └── utils
│ └── hub.py
├── centernet.py
├── losses
│ ├── __init__.py
│ ├── gaussian_focal_loss.py
│ └── l1_loss.py
├── modules
│ ├── __init__.py
│ ├── conv_module.py
│ ├── head.py
│ └── neck.py
└── utils
├── __init__.py
├── gaussian_target.py
└── nms.py
実装
以降では,実際にCenterNet
のコードを書いていきます.
なお,本記事ではbackbone
にResNet18
を使用するという前提のもと記述していきます.
CTResNetNeck
ここでは,CTResNetNeck
を実装します.
なお,ConvModule
についてはGitHubのレポジトリで確認してください.
import torch
from torch import nn

from .conv_module import ConvModule
class CTResNetNeck(nn.Module):
    """The neck used in `CenterNet <https://arxiv.org/abs/1904.07850>`_ for
    object classification and box regression.

    Each stage applies a 3x3 ConvModule (channel projection) followed by a
    stride-2 transposed ConvModule, doubling the spatial resolution per stage.

    Args:
        in_channels (int): Number of input channels.
        num_deconv_filters (tuple[int]): Number of filters per stage.
        num_deconv_kernels (tuple[int]): Deconv kernel size per stage; must be
            the same length as ``num_deconv_filters``.
    """

    def __init__(
        self,
        in_channels: int,
        num_deconv_filters,
        num_deconv_kernels,
    ) -> None:
        super().__init__()
        # Every upsampling stage needs both a filter count and a kernel size.
        assert len(num_deconv_filters) == len(num_deconv_kernels)
        self.in_channels = in_channels
        self.deconv_layers = self._make_deconv_layer(
            num_deconv_filters, num_deconv_kernels
        )

    def _make_deconv_layer(self, num_deconv_filters, num_deconv_kernels):
        """Use deconv layers to upsample backbone's output.

        NOTE: updates ``self.in_channels`` stage by stage; after this call it
        holds the channel count of the final stage's output.
        """
        layers = []
        # Iterate filter/kernel pairs directly instead of indexing by
        # range(len(...)).
        for feat_channels, kernel_size in zip(num_deconv_filters, num_deconv_kernels):
            # 3x3 conv: change channel count, keep spatial resolution.
            layers.append(
                ConvModule(
                    in_channels=self.in_channels,
                    out_channels=feat_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    conv_fn="Conv2d",
                    norm_fn="BatchNorm2d",
                )
            )
            # Stride-2 transposed conv: double the spatial resolution.
            layers.append(
                ConvModule(
                    in_channels=feat_channels,
                    out_channels=feat_channels,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=1,
                    conv_fn="ConvTranspose2d",
                    norm_fn="BatchNorm2d",
                )
            )
            self.in_channels = feat_channels
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample feature map ``x`` through all deconv stages."""
        return self.deconv_layers(x)
CenterNetHead
from typing import Dict
import torch
from torch import nn
class CenterNetHead(nn.Module):
    """Prediction head of CenterNet.

    Runs three parallel branches over a shared feature map: a per-class
    center heatmap, a 2-channel box-size (w, h) map, and a 2-channel
    center-offset map.

    Args:
        in_channels (int): Channels of the incoming feature map.
        feat_channels (int): Hidden channels inside each branch.
        num_classes (int): Number of object classes (heatmap channels).
    """

    def __init__(
        self, in_channels: int = 64, feat_channels: int = 64, num_classes: int = 4
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.num_classes = num_classes
        # The three branches share one layout; only the output width differs.
        self.heatmap_head = self._build_head(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=num_classes,
        )
        self.wh_head = self._build_head(
            in_channels=in_channels, feat_channels=feat_channels, out_channels=2
        )
        self.offset_head = self._build_head(
            in_channels=in_channels, feat_channels=feat_channels, out_channels=2
        )

    @staticmethod
    def _build_head(in_channels, feat_channels, out_channels) -> nn.Module:
        """Build head for each branch: 3x3 conv -> ReLU -> 1x1 conv."""
        return nn.Sequential(
            nn.Conv2d(in_channels, feat_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(feat_channels, out_channels, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Apply every branch to ``x`` and collect the outputs by name."""
        return {
            "heatmap": self.heatmap_head(x),
            "wh": self.wh_head(x),
            "offset": self.offset_head(x),
        }
CenterNet
from typing import Dict, Tuple
import torch
from torch import nn
from .backbone import resnet18
from .modules import CenterNetHead, CTResNetNeck
class CenterNet(nn.Module):
    """CenterNet detector: ResNet18 backbone -> deconv neck -> prediction head.

    Args:
        num_classes (int): Number of object classes to detect.
    """

    def __init__(self, num_classes: int = 4) -> None:
        super().__init__()
        self.num_classes = num_classes
        # Pretrained ResNet18 trunk; its final feature map has 512 channels.
        self.backbone = resnet18(pretrained=True)
        # Three deconv stages: 512 -> 256 -> 128 -> 64 channels, each 2x upsample.
        self.neck = CTResNetNeck(
            in_channels=512,
            num_deconv_filters=(256, 128, 64),
            num_deconv_kernels=(4, 4, 4),
        )
        self.bbox_head = CenterNetHead(
            in_channels=64, feat_channels=64, num_classes=num_classes
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return the head's {"heatmap", "wh", "offset"} prediction dict."""
        features = self.backbone(x)
        upsampled = self.neck(features)
        return self.bbox_head(upsampled)
モデルの確認
ここで,今回作成したモデルを確認してみましょう.
うまくいっていればモデルをprint
することで確認できます.
CenterNet(
(backbone): ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): ResLayer(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer2): ResLayer(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer3): ResLayer(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer4): ResLayer(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
)
(neck): CTResNetNeck(
(deconv_layers): Sequential(
(0): ConvModule(
(conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(2): ConvModule(
(conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(3): ConvModule(
(conv): ConvTranspose2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(4): ConvModule(
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(5): ConvModule(
(conv): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
(bbox_head): CenterNetHead(
(heatmap_head): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1))
)
(wh_head): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
(offset_head): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
)
)
おわりに
今回は backbone に ResNet18 を使った CenterNet を PyTorch で実装しました.
前回の記事同様に,原論文を読みながらの実装だったので,大変な部分も多々ありましたが,なんとか実装[1]することができました.
間違っている点などありましたら,優しく指摘して頂けると嬉しいです.
-
参考: MMDetection ↩︎
Discussion