🕵️‍♀️

Gemini-2.0-flashにバウンディングボックス描画をさせて物体検出能力を試す

2024/12/14に公開

執筆日

2024/12/14

概要

12/11にGoogleのバージョンアップしたgemini-2.0が登場しました。
https://blog.google/intl/ja-jp/company-news/technology/google-gemini-ai-update-december-2024/
その中で特に興味を引いたのが画像のバウンディングボックス検出がかなり精度よくできてしまうのが話題になっていたことです(僕が知らなかっただけでgemini-1.5-proの頃からできていたようですが)。これができるとそのまま使うもよし、蒸留学習用のアノテーションデータの大量生成などかなり応用の幅が効き面白いと思っていたので試してみました。

準備

  • 持っていない人はgeminiのAPI KEYを作成しましょう。gemini-2.0-flash-expを使う場合は支払情報を登録して有料ユーザーになっておく必要があります。
  • 依存ライブラリインストール
pip install google-generativeai, google-genai

スクリプト

gemini-2.0発表翌日には公式がcookbookを出してくれていたのでこちらを参考に処理を書いています。(google-colab用なのでローカルjupyterで動かす用に調整、不要な部分を少し修正しています)

折り畳み内にも書いていますが、geminiが返すbbox座標はサイズを1000x1000画像にリサイズしたときの座標を返しているらしくnormalizeが必要になるところが注意が必要な点です。

スクリプト
gemini用のkeyとモデルの設定
from dotenv import load_dotenv
import os
from google import genai
from google.genai import types

load_dotenv(".env")

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME")
client = genai.Client(api_key=GEMINI_API_KEY)
model_name = GEMINI_MODEL_NAME # "gemini-1.5-flash-001","gemini-1.5-pro-002","gemini-2.0-flash-exp"
import
import io
from io import BytesIO
import os
import requests
import json
import random

import google.generativeai as genai
from PIL import Image, ImageDraw, ImageFont, ImageColor

以下の部分がキモで、geminiが返すbbox座標はサイズを1000x1000画像にリサイズしたときの座標を返しているらしくnormalizeが必要になります。画像が正方形なら端から1000x1000にリサイズしちゃってもいいかもしれません。出力が\``json <output>````で囲まれてしまうのを消すのはよくあるやつです。

process gemini output
# @title Parsing JSON output
def parse_json(json_output):
    # We parse out the markdown fencing
    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        if line == "```json":
            json_output = "\n".join(lines[i+1:])  # Remove everything before "```json"
            json_output = json_output.split("```")[0]  # Remove everything after the closing "```"
            break  # Exit the loop once "```json" is found
    return json_output

def plot_bounding_boxes(im: Image.Image, bounding_boxes):
    """
    Plots bounding boxes on an image with markers for each a name, using PIL, normalized coordinates, and different colors.

    Args:
        img_path: The path to the image file.
        bounding_boxes: A list of bounding boxes containing the name of the object
        and their positions in normalized [y1 x1 y2 x2] format.
    """

    # Load the image
    img = im
    width, height = img.size
    print(img.size)
    # Create a drawing object
    draw = ImageDraw.Draw(img)

    # Define a list of colors
    colors = ["red"]

    # We parse out the markdown fencing
    bounding_boxes = parse_json(bounding_boxes)

    # font = ImageFont.truetype("NotoSansCJK-Regular.ttc", size=14)
    font = ImageFont.truetype("meiryo.ttc", size=28, index=1) # サンプルで使われていたフォントがインストール出来なかったためMeiryoを仕様

    # Iterate over the bounding boxes
    for i, bounding_box in enumerate(json.loads(bounding_boxes)):
        # Select a color from the list
        color = colors[i % len(colors)]

        # Convert normalized coordinates to absolute coordinates
        abs_y1 = int(bounding_box["box_2d"][0]/1000 * height)
        abs_x1 = int(bounding_box["box_2d"][1]/1000 * width)
        abs_y2 = int(bounding_box["box_2d"][2]/1000 * height)
        abs_x2 = int(bounding_box["box_2d"][3]/1000 * width)

        if abs_x1 > abs_x2:
            abs_x1, abs_x2 = abs_x2, abs_x1

        if abs_y1 > abs_y2:
            abs_y1, abs_y2 = abs_y2, abs_y1

        # Draw the bounding box
        draw.rectangle(
            ((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4
        )

        # Draw the text
        if "label" in bounding_box:
            draw.text((abs_x1 + 8, abs_y1 + 6), bounding_box["label"], fill=color, font=font)

プロンプトを英語と日本語で試してみましたがどちらも問題なく同じ回答を生成してくれました。

gemini response
# bounding_box_system_instructions = """
#     Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects.
#     If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc..).
# """
# prompt = "Detect the 2d bounding box of the pen (“label” is the “description” of the pen)."  # @param {type:"string"}

bounding_box_system_instructions = """
    バウンディングボックスをラベル付きのJSON配列として返す。マスクやコード・フェンシングは返さない。オブジェクトは25個まで。
    オブジェクトが複数存在する場合は、固有の特性(色、サイズ、位置、固有の特性など)に応じて名前を付ける。
"""
prompt = "ペンの2次元バウンディングボックスを検出する(「label」をペンの「description」とする)。"  # @param {type:"string"}

safety_settings = [
    types.SafetySetting(
        category="HARM_CATEGORY_DANGEROUS_CONTENT",
        threshold="BLOCK_ONLY_HIGH",
    ),
]

# Load and resize image
img = Image.open(BytesIO(open(image_path, "rb").read()))
print("raw image size:", img.size)
im = Image.open(image_path).resize((1024, int(1024 * img.size[1] / img.size[0])), Image.Resampling.LANCZOS)
print("resized image size:", im.size)

# Run model to find bounding boxes
response = client.models.generate_content(
    model=model_name,
    contents=[prompt, im],
    config = types.GenerateContentConfig(
        system_instruction=bounding_box_system_instructions,
        temperature=0.5,
        safety_settings=safety_settings,
    )
)

# Check output
print(response.text)
output
plot_bounding_boxes(im, response.text)
display(im)

普段GPTでガチガチに出力形式をインストラクションしている身としては、こんな指示で安定した出力になるかちょっと不安になりますが何回かやっても問題なく処理できました。こんな一貫性のある人間になりたいです。

結果

部屋にあったペンで試してみました。スマホで撮影したそれなりに画質のいい画像ではありますが、かなり正確に検出できていますね。

おわり

ViTが物体検出でかなり精度よく物体検出できSAMなどのセグメンテーションモデルがすごいというのは知っていましたが、マルチモーダルLLMが既にこのレベルで物体検出も学習できてしまっているのは予想外でした。LLMの情報処理は人間の認知に近いと勝手に思っていたので、座標を数値化するようなことは学んでないと思っていました(人間も視覚情報から物体の位置を定性的に判断することはできますが、それぞれの座標を数値化するのは難しいのと同じことがLLMでも起きると思っていました)。これからRGBだけじゃなくてdepthとかも学習したVLMが出てきたら面白いだろうなあ。
鮮明な画像じゃなくても上手くいくか、複雑な指示でも上手に検出するかなど試してみようと思います。

ヘッドウォータース

Discussion