🧮

物体検出をDETRモデルで行う

に公開

サマリ

今回はDETRという機械学習モデルを使用して、物体検出を行います。
DETRは対象の画像をCNNで畳み込みTransformerモデル内でベクトルの類似度から、物体を推定するアルゴリズムだったかと覚えています。
上記は完全に正しい説明ではないため参考程度にしていただき、詳しくはリポジトリや論文を参照ください。

マシンスペック

今回はgooglecolabで行います。

画像内に何が写っているか判別する

対話形式で進めます。

# モデルのインポート
import torch as th
import torchvision.transforms as T
import requests
from PIL import Image, ImageDraw, ImageFont

# GithubからDETRのモデルをダウンロード
model = th.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
model.eval()
# GPUをActivate
model = model.cuda()

Downloading: "https://github.com/facebookresearch/detr/zipball/main" to /root/.cache/torch/hub/main.zip
/usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
/usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or None for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing weights=ResNet50_Weights.IMAGENET1K_V1. You can also use weights=ResNet50_Weights.DEFAULT to get the most up-to-date weights.
warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 163MB/s]
Downloading: "https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth" to /root/.cache/torch/hub/checkpoints/detr-r50-e632da11.pth
100%|██████████| 159M/159M [00:01<00:00, 118MB/s]

# 画像を正則化
transform = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# カテゴリを定義
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

# 任意の画像を読み込む
url = input()

任意の画像を指定する。

# 画像を読み込みリサイズ
img = Image.open(requests.get(url, stream=True).raw).resize((800,600)).convert('RGB')
# 画像を表示
img

# 画像をテンソル形式に変換、GPUを使用
img_tens = transform(img).unsqueeze(0).cuda()

# 分類の実行
with th.no_grad():
  output = model(img_tens)

# 画像をコピー
im2 = img.copy()
# Pillowの形式に変換
drw = ImageDraw.Draw(im2)

# 各クラスの予測値
pred_logits=output['pred_logits'][0][:, :len(CLASSES)]
# バウンディングボックスの位置
pred_boxes=output['pred_boxes'][0]

# Softmax関数を適用
max_output = pred_logits.softmax(-1).max(-1)
# 物体の検知数の上限を決定
topk = max_output.values.topk(2)

# 上限を適用
pred_logits = pred_logits[topk.indices]
pred_boxes = pred_boxes[topk.indices]
pred_logits.shape

# 検出
for logits, box in zip(pred_logits, pred_boxes):
  # 
  cls = logits.argmax()
  
  if cls >= len(CLASSES):
    continue

  label = CLASSES[cls]
  # 認識された物体のラベルを確認
  print(label)
  # 
  box = box.cpu() * th.Tensor([800, 600, 800, 600])
  x, y, w, h = box
  x0, x1 = x-w//2, x+w//2
  y0, y1 = y-h//2, y+h//2
  # バウンディングボックスの設定
  drw.rectangle([x0, y0, x1, y1], outline='red', width=5)
  # ラベルの設定
  drw.text((x, y), label, fill='white')

cat
bed

# 検出された画像を出力
im2

Discussion