🎉

gradioでSAM2のデバッグツールを作ってみた

2024/09/08に公開

GenieeのCTO 孟です。

最近SAM2の可能性を模索したいと思いまして、
https://github.com/facebookresearch/segment-anything-2
SAM2に画像とpositionの座標をセットでinputさせる必要があって、なかなかdebugしづらいなと思って、
Claude 3.5 Sonnetを用いて、debugしやすくするようにツール群を1時間程度で作ってもらいました。

まずgradio.pyのコードから:

import gradio as gr
import requests
from PIL import Image
import numpy as np
import io
import base64

# Global variable for API base URL
API_BASE_URL = "http://127.0.0.1:5001/sam2"

def show_coordinates(image, evt: gr.SelectData, normalize: str, epsilon: float):
    if image is None:
        return "Please upload an image first.", None
    
    x, y = evt.index
    
    try:
        # Convert numpy array to PIL Image
        pil_image = Image.fromarray(image.astype('uint8'), 'RGB')

        # Save the image to a bytes buffer
        img_byte_arr = io.BytesIO()
        pil_image.save(img_byte_arr, format='PNG')
        img_byte_arr = img_byte_arr.getvalue()

        # Make the API call
        nml_param = 1 if normalize == "Apply Normalization" else 0
        url = f"{API_BASE_URL}/predict?x={x}&y={y}&nml={nml_param}&epsilon={epsilon}"
        files = {"image": ("image.png", img_byte_arr, "image/png")}
        response = requests.post(url, files=files)

        if response.status_code == 200:
            result = response.json()
            image_base64 = result.get("image_base64")
            
            if image_base64:
                # Decode base64 string to image
                image_data = base64.b64decode(image_base64)
                result_image = Image.open(io.BytesIO(image_data))
                
                # Convert PIL Image to numpy array for Gradio
                result_numpy = np.array(result_image)
                
                return f"X: {x}, Y: {y}", result_numpy
            else:
                return f"X: {x}, Y: {y}. Error: No image data in response", None
        else:
            return f"X: {x}, Y: {y}. API Error: {response.status_code}", None

    except Exception as e:
        return f"X: {x}, Y: {y}. Error: {str(e)}", None

with gr.Blocks() as demo:
    gr.Markdown("# Image Upload and Cursor Position Demo")

    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload Image", height=400)
        with gr.Column():
            coordinates = gr.Textbox(label="Cursor Position")
            normalize = gr.Radio(
                ["No Normalization", "Apply Normalization"], 
                label="Edge Normalization", 
                value="No Normalization"
            )
            epsilon = gr.Slider(
                minimum=0.001, 
                maximum=0.1, 
                value=0.02, 
                step=0.001, 
                label="Normalization Strength (Epsilon)"
            )

    result_image = gr.Image(label="Result Image", height=600, show_download_button=True)

    # Set up the event listener
    image_input.select(
        show_coordinates,
        inputs=[image_input, normalize, epsilon],
        outputs=[coordinates, result_image]
    )

    gr.Markdown("""
    ## Instructions:
    1. Upload an image using the 'Upload Image' component.
    2. Select whether to apply edge normalization or not.
    3. Adjust the Normalization Strength (Epsilon) if applying normalization.
    4. Click on any point in the uploaded image.
    5. The coordinates of your click will be displayed, and the result image will be shown below.
    6. You can change the normalization option or strength at any time and click again to see the difference.
    7. If the result image is not fully visible, you can:
       - Click on the image to open it in full size in a new tab.
       - Use the download button to save and view the full image locally.
    """)

demo.launch(server_name="0.0.0.0")

そして、sam2を動かすサーバサイドのコードがこちらです。
サブディレクトリで動かしたい方、top_pathを変えてみるといいです。

from flask import Flask, request, jsonify
from PIL import Image, ImageDraw
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import cv2
import base64
import io

top_path = "/sam2"

app = Flask(__name__)

checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))

def apply_mask(image, mask, color, alpha=0.5):
    for c in range(3):
        image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255
    return image

def draw_mask(image, mask, color, alpha=0.5, borders=True):
    masked = apply_mask(image.copy(), mask, color, alpha)

    if borders:
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(masked, contours, -1, color, 2)

    return masked

def draw_point(image, x, y, color, size=5):
    cv2.circle(image, (int(x), int(y)), size, color, -1)
    return image

def straighten_mask_edges(mask, epsilon:float=0.02):
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    straight_mask = np.zeros_like(mask, dtype=np.uint8)
    
    for contour in contours:
        epsilon = epsilon * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        cv2.drawContours(straight_mask, [approx], 0, 1, -1)
    
    return straight_mask.astype(bool)

def process_image(image, masks, scores, point_coords, input_labels, normalize=False, epsilon:float=0.02):
    for i, (mask, score) in enumerate(zip(masks, scores)):
        if normalize:
            mask = straighten_mask_edges(mask, epsilon=epsilon)
        
        color = [30/255, 144/255, 255/255]  # Light blue color
        image = draw_mask(image, mask, color, alpha=0.5, borders=True)

    for point, label in zip(point_coords, input_labels):
        color = (0, 255, 0) if label == 1 else (255, 0, 0)
        image = draw_point(image, point[0], point[1], color)

    return image

@app.route(f'{top_path}/predict', methods=['POST'])
def predict():
    x = request.args.get('x', type=int)
    y = request.args.get('y', type=int)
    normalize = request.args.get('nml', type=int, default=0)
    epsilon = request.args.get('epsilon', type=float, default=0.02)

    if x is None or y is None:
        return jsonify({"error": "Missing x or y coordinates"}), 400

    if 'image' not in request.files:
        return jsonify({"error": "No image file provided"}), 400

    image_file = request.files['image']
    image = Image.open(image_file)
    image = np.array(image.convert("RGB"))

    predictor.set_image(image)

    input_point = np.array([[x, y]])
    input_label = np.array([1])

    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]

    processed_image = process_image(image, masks, scores, input_point, input_label, normalize=bool(normalize), epsilon=epsilon)

    # Convert the processed image to base64
    img = Image.fromarray(processed_image)
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    return jsonify({"image_base64": img_str})

@app.route(f"{top_path}/ping")
def ping():
    return "pong", 200

if __name__ == '__main__':
    app.run(debug=True, port="5001")

そして、そのホスト側のデーモンを、自分の場合はsam2フォルダに直接置いて動かしたので、変えたい方は、パスを変えてみるといいと思います。
SAM2の入れ方

python3 -m venv venv
source venv/bin/activate

git clone https://github.com/facebookresearch/segment-anything-2.git
cd segment-anything-2 & pip install -e .

cd checkpoints && \
./download_ckpts.sh && \
cd ..

そして、Gradioを動かしてWebUI上でこんなことができるようになります。
マウスをクリックした箇所に、Segment Anything 2 (SAM2)で予測させながら、結果をほぼリアルタイムに確認することが可能となります。

エッジ部分の処理をもう少しスムーズ化したかったので、検証したいことをパラメータで調整しながら結果を素早く確認できるツールとして、悪くないじゃないかと思いました。

こちらのレポジトリで少しずつ更新を加えるので、興味ある方watchして頂けると幸いです。
(2点までのマルチポイントpromptの機能追加と、box選択の機能追加を加えてます)
https://github.com/itsusony/sam2_debugger

弊社ジーニーでは常時採用を行っておりますので、最先端の生成AIに限らず、機械学習、データ、BI、アドテク、マーテクに興味をお持ちの方は、こちらをチェックしていただけますと幸いです。

https://hrmos.co/pages/geniee/jobs?category=1365935360094986240

最先端技術に情熱を注ぐ仲間たちと、共に革新的なソリューションを生み出していくことを楽しみにしています。ぜひ私たちの挑戦にご参加ください。

GENIEE TechBlog

Discussion