🎙️

【超簡単】事前学習済みモデルRetinaNetを使った物体検出をご紹介します

2023/04/23に公開

事前学習済みモデルのRetinaNetを使った物体検出についてご紹介します。
このモデルは、高精度とスピードが特長で幅広い分野で採用されていますが、非常に実装が簡単です。

このブログを最後まで読むと、以下のような実装ができます。
こちらの写真は、有名なビートルズのアルバム「アビィーロード」のジャケット写真です。

モデルの概要

  • RetinaNetは、物体検出のための深層学習モデルの一種であり、ResNet50と呼ばれる深層学習モデルと、FPN(Feature Pyramid Network)と呼ばれる特徴抽出ネットワークを組み合わせて使用することができます。
  • ResNet50では、最終層を物体検出のための検出器に変更し、この検出器を使用して画像内の物体を検出します。
  • FPNは、RetinaNetで使用される特徴抽出ネットワークであり、畳み込み層から抽出された特徴マップを異なる解像度で複数のピラミッド層に分割し、各ピラミッド層で特徴量を補完することで、複数のスケールで物体を検出することができます。このため、RetinaNetは、様々な物体の大きさに対応でき、高い検出精度を発揮することができます。

必要なライブラリーのインストール

import numpy as np
import torch
import torchvision
!pip install pytorch_lightning
import pytorch_lightning as pl
from torchvision import transforms
from PIL import Image
from PIL import ImageDraw, ImageFont
import matplotlib.pyplot as plt

画像ファイルの読み込み

  • 画像を読み込んだうえ、テンソルに変換して表示します。
path = "/content/abbey-road_tcm30-594684.jpeg"
img = Image.open(path)
transform = transforms.ToTensor()
x = transform(img)
plt.imshow(x.permute(1, 2, 0))
  • pathには、画像のファイルパス "/content/abbey-road_tcm30-594684.jpeg" を設定します。
  • Image.open関数を使って、pathで指定された画像ファイルを読み込みます。
  • transforms.ToTensor()関数は、画像をPyTorchのテンソル形式に変換するものです。

これを実行すると、こうなります。オリジナルの画像が表示されますね。

RetinaNet+ResNet50+FPNモデルを使った物体検出

事前学習済みモデルの読み込み

from torchvision.models.detection import retinanet_resnet50_fpn

#seedを固定
pl.seed_everything(0)

#事前学習済みのRetinaNet ResNet50 FPNモデルをインスタンス化
model = retinanet_resnet50_fpn(pretrained=True)

#モデルは初期設定では学習モードになっている。推論結果を確認するためeval()にする
model.eval()
print(model.training)

#推論
#unsqueeze(0)として0次元目にデータ数1を挿入する
#リストで返ってくる。リストの中身は要素1つしかないが、その後の処理をしやすくするため[0]を入れる。(リスト→辞書)
y = model(x.unsqueeze(0))[0]

必要なパラメータの準備

  • 事前準備として、フォントをインストールしておきます。
  • アノテーションに使うラベルデータも用意します。

#フォントのインストール
%%capture
!if [ ! -d fonts ]; then mkdir fonts && cd fonts && wget https://noto-website-2.storage.googleapis.com/pkgs/NotoSansCJKjp-hinted.zip && unzip NotoSansCJKjp-hinted.zip && cd .. ;fi


#ラベルデータ
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

可視化のための関数を用意します。

def visualize_results(input, output, threshold):

    image = input.permute(1, 2, 0).numpy()
    image = Image.fromarray((image*255).astype(np.uint8))
    boxes = output["boxes"].cpu().detach().numpy()
    labels = output["labels"].cpu().detach().numpy()

    if "scores" in output.keys():
        scores = output["scores"].cpu().detach().numpy()
        boxes = boxes[scores > threshold]
        labels = labels[scores > threshold]

    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype("/content/fonts/NotoSansCJKjp-Bold.otf",16)

    for box , label in zip(boxes, labels):
        draw.rectangle(box,outline="blue")
        text = COCO_INSTANCE_CATEGORY_NAMES[label]
        w, h = font.getsize(text)
        draw.rectangle([box[0], box[1], box[0]+w, box[1]+h], fill='blue')
        draw.text((box[0], box[1]), text, font=font, fill="white")

    return image

推論結果の可視化

visualize_results(x, y, 0.5)

うまくいきましたね!

Discussion