🚀

Zero shot object detection-GroundingDINOを試してみた

2023/05/09に公開

Zeroshot object detectionとは

Zeroshot object detectionは新たにクラス数を定めたデータセットなどを用意して再学習を行わなくても、クラス名(猫とか犬とか)を変更するだけでその物体を検出する物体認識の技術です。
プロンプトとモデルを変更するだけで、説明したオブジェクトを検出することができます。
その中でもGrounding DINOがSOTAのモデルとなっています。
https://github.com/IDEA-Research/GroundingDINO

リンク

Colab
github

準備

Google Colabを開き、メニューから「ランタイム→ランタイムのタイプを変更」でランタイムを「GPU」に変更します。

環境構築

インストール手順です。重みのダウンロードもここで行なっています。

import os
HOME = "/content"
%cd {HOME}
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd {HOME}/GroundingDINO
!pip install -q -e .
!pip install -q roboflow

CONFIG_PATH = os.path.join(HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
print(CONFIG_PATH, "; exist:", os.path.isfile(CONFIG_PATH))

%cd {HOME}
!mkdir {HOME}/weights
%cd {HOME}/weights

!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
WEIGHTS_NAME = "groundingdino_swint_ogc.pth"
WEIGHTS_PATH = os.path.join(HOME, "weights", WEIGHTS_NAME)
print(WEIGHTS_PATH, "; exist:", os.path.isfile(WEIGHTS_PATH))

%cd {HOME}
!mkdir {HOME}/data
%cd {HOME}/data

!wget -q https://media.roboflow.com/notebooks/examples/dog.jpeg
!wget -q https://media.roboflow.com/notebooks/examples/dog-2.jpeg
!wget -q https://media.roboflow.com/notebooks/examples/dog-3.jpeg
!wget -q https://media.roboflow.com/notebooks/examples/dog-4.jpeg

%cd {HOME}/GroundingDINO

推論

(1) ライブラリのインポート

import os
import supervision as sv
from PIL import Image
import numpy as np
import cv2

import pycocotools.mask as mask_util
from groundingdino.util.inference import load_model, load_image, predict, annotate

(2)検出モデルのLoad

HOME = "/content"
CONFIG_PATH = os.path.join(HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
WEIGHTS_NAME = "groundingdino_swint_ogc.pth"
WEIGHTS_PATH = os.path.join(HOME, "weights", WEIGHTS_NAME)
detection_model = load_model(CONFIG_PATH, WEIGHTS_PATH)

(3) テストデータのダウンロード

!wget -q https://media.roboflow.com/notebooks/examples/dog.jpeg -O /content/test.jpeg

(4)推論
今回は(3)でダウンロードした写真からdogとbuildingを推論してみたいと思います。
dog

IMAGE_PATH = "/content/test.jpeg"

TEXT_PROMPT = "dog"  
BOX_TRESHOLD = 0.3  
TEXT_TRESHOLD = 0.3  

image_source, image = load_image(IMAGE_PATH)

boxes, logits, phrases = predict(
    model=detection_model, 
    image=image, 
    caption=TEXT_PROMPT, 
    box_threshold=BOX_TRESHOLD, 
    text_threshold=TEXT_TRESHOLD
)

annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)

%matplotlib inline  
sv.plot_image(annotated_frame, (16, 16))

building

IMAGE_PATH = "/content/test.jpeg"

TEXT_PROMPT = "building"  # rectangle, region, shape region, segment, fragment
BOX_TRESHOLD = 0.3  # 0.05~0.1 -> 部品点数とか図面の密度に応じている.
TEXT_TRESHOLD = 0.3  # 0.05~0.1

image_source, image = load_image(IMAGE_PATH)

boxes, logits, phrases = predict(
    model=detection_model, 
    image=image, 
    caption=TEXT_PROMPT, 
    box_threshold=BOX_TRESHOLD, 
    text_threshold=TEXT_TRESHOLD
)

annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)

%matplotlib inline  
sv.plot_image(annotated_frame, (16, 16))

推論結果

ただテストサンプルを試しただけだと面白くないので、別のイメージについても色々と試してみましょう。同じ推論パイプラインを実行するので関数を定義しておきます。

def inference_pipeline(image_path, text_prompt, box_threshold=0.3, text_threshold=0.3):
  image_source, image = load_image(image_path)

  boxes, logits, phrases = predict(
      model=detection_model, 
      image=image, 
      caption=text_prompt, 
      box_threshold=box_threshold, 
      text_threshold=text_threshold
  )

  annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)

  %matplotlib inline  
  sv.plot_image(annotated_frame, (16, 16))

(1)サッカー(近い)
画像の準備

!wget https://prtimes.jp/i/1355/5420/resize/d1355-5420-149832-0.jpg -O /content/soccer.jpg

推論

inference_pipeline("/content/soccer.jpg", "player")  # player only
# inference_pipeline("/content/soccer.jpg", "ball", 0.2, 0.2)  # ball only
# inference_pipeline("/content/soccer.jpg", "player,ball", 0.2, 0.2)  # multiple outputs

結果
player

ball

もし検出結果がない場合はbox_threshold, text_thresholdあたりを小さくすると上手くいくケースが多いです。

複数クラスの検出も可能です。
player,ball

(2)サッカー(遠い)
画像の準備

!wget https://prtimes.jp/i/31288/7/resize/d31288-7-332931-0.png -O /content/soccer.png

推論

inference_pipeline("/content/soccer.png", "player", 0.2, 0.2)
# inference_pipeline("/content/soccer.png", "goal", 0.2, 0.2)

結果
Player

goal

(3)交差点
画像準備

!wget https://camera-map.com/wp-content/uploads/ann-shibuya.jpg -O /content/intersection.jpg

推論

inference_pipeline("/content/intersection.jpg", "human", 0.15, 0.15)
# inference_pipeline("/content/intersection.jpg", "car", 0.17, 0.17)

結果
human

Car

最後に

zeroshot object detectionのSOTAであるGrounding DINOを試してみました。COCO datasetsのクラスだけでは検出できないものもpromptや閾値を変更することで検出できるようになっているのがわかったと思います。
zeroshot object detectionは物体検出をするFirst stepとして非常に有効な手段であると考えています。zeroshotで出来ない場合新しくdatasetを作成してyolov8やyolonasのようなモデルでのfinetuningを実行するなどで格段に実装速度が変化すると感じました。
また、Segment Anything等と組み合わせると更なる真価を発揮するのがこのzeroshot、Segment Anythingと組み合わせた事例についても紹介していこうと思うので今後ともよろしくお願いします。

Discussion