Building a SAM2 Debugging Tool with Gradio
I'm Meng, CTO of Geniee.
I recently wanted to explore what SAM2 can do, so I had Claude 3.5 Sonnet build a set of tools in about an hour to make debugging easier.
First, the front-end code (one caveat: don't literally name the file gradio.py, or `import gradio` will pick up the script itself instead of the library; a name like app.py is safer):
```python
import gradio as gr
import requests
from PIL import Image
import numpy as np
import io
import base64

# Global variable for API base URL
API_BASE_URL = "http://127.0.0.1:5001/sam2"

def show_coordinates(image, evt: gr.SelectData, normalize: str, epsilon: float):
    if image is None:
        return "Please upload an image first.", None

    x, y = evt.index
    try:
        # Convert numpy array to PIL Image
        pil_image = Image.fromarray(image.astype('uint8'), 'RGB')

        # Save the image to a bytes buffer
        img_byte_arr = io.BytesIO()
        pil_image.save(img_byte_arr, format='PNG')
        img_byte_arr = img_byte_arr.getvalue()

        # Make the API call
        nml_param = 1 if normalize == "Apply Normalization" else 0
        url = f"{API_BASE_URL}/predict?x={x}&y={y}&nml={nml_param}&epsilon={epsilon}"
        files = {"image": ("image.png", img_byte_arr, "image/png")}
        response = requests.post(url, files=files)

        if response.status_code == 200:
            result = response.json()
            image_base64 = result.get("image_base64")
            if image_base64:
                # Decode base64 string to image
                image_data = base64.b64decode(image_base64)
                result_image = Image.open(io.BytesIO(image_data))
                # Convert PIL Image to numpy array for Gradio
                result_numpy = np.array(result_image)
                return f"X: {x}, Y: {y}", result_numpy
            else:
                return f"X: {x}, Y: {y}. Error: No image data in response", None
        else:
            return f"X: {x}, Y: {y}. API Error: {response.status_code}", None
    except Exception as e:
        return f"X: {x}, Y: {y}. Error: {str(e)}", None

with gr.Blocks() as demo:
    gr.Markdown("# Image Upload and Cursor Position Demo")

    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload Image", height=400)
        with gr.Column():
            coordinates = gr.Textbox(label="Cursor Position")
            normalize = gr.Radio(
                ["No Normalization", "Apply Normalization"],
                label="Edge Normalization",
                value="No Normalization"
            )
            epsilon = gr.Slider(
                minimum=0.001,
                maximum=0.1,
                value=0.02,
                step=0.001,
                label="Normalization Strength (Epsilon)"
            )

    result_image = gr.Image(label="Result Image", height=600, show_download_button=True)

    # Set up the event listener
    image_input.select(
        show_coordinates,
        inputs=[image_input, normalize, epsilon],
        outputs=[coordinates, result_image]
    )

    gr.Markdown("""
## Instructions:
1. Upload an image using the 'Upload Image' component.
2. Select whether to apply edge normalization or not.
3. Adjust the Normalization Strength (Epsilon) if applying normalization.
4. Click on any point in the uploaded image.
5. The coordinates of your click will be displayed, and the result image will be shown below.
6. You can change the normalization option or strength at any time and click again to see the difference.
7. If the result image is not fully visible, you can:
   - Click on the image to open it in full size in a new tab.
   - Use the download button to save and view the full image locally.
""")

demo.launch(server_name="0.0.0.0")
```
Next, here is the server-side code that runs SAM2.
If you want to serve it under a different subdirectory, just change top_path.
```python
from flask import Flask, request, jsonify
from PIL import Image, ImageDraw
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import cv2
import base64
import io

top_path = "/sam2"

app = Flask(__name__)

checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))

def apply_mask(image, mask, color, alpha=0.5):
    # Alpha-blend the mask color into the image, channel by channel
    for c in range(3):
        image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255
    return image

def draw_mask(image, mask, color, alpha=0.5, borders=True):
    masked = apply_mask(image.copy(), mask, color, alpha)
    if borders:
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # Scale the 0-1 float color up to 0-255 so the border is actually visible
        border_color = [c * 255 for c in color]
        cv2.drawContours(masked, contours, -1, border_color, 2)
    return masked

def draw_point(image, x, y, color, size=5):
    cv2.circle(image, (int(x), int(y)), size, color, -1)
    return image

def straighten_mask_edges(mask, epsilon: float = 0.02):
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    straight_mask = np.zeros_like(mask, dtype=np.uint8)
    for contour in contours:
        # Per-contour tolerance; don't overwrite epsilon itself,
        # or the value compounds across contours
        tolerance = epsilon * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, tolerance, True)
        cv2.drawContours(straight_mask, [approx], 0, 1, -1)
    return straight_mask.astype(bool)

def process_image(image, masks, scores, point_coords, input_labels, normalize=False, epsilon: float = 0.02):
    for i, (mask, score) in enumerate(zip(masks, scores)):
        if normalize:
            mask = straighten_mask_edges(mask, epsilon=epsilon)
        color = [30/255, 144/255, 255/255]  # Light blue color
        image = draw_mask(image, mask, color, alpha=0.5, borders=True)
    for point, label in zip(point_coords, input_labels):
        color = (0, 255, 0) if label == 1 else (255, 0, 0)
        image = draw_point(image, point[0], point[1], color)
    return image

@app.route(f'{top_path}/predict', methods=['POST'])
def predict():
    x = request.args.get('x', type=int)
    y = request.args.get('y', type=int)
    normalize = request.args.get('nml', type=int, default=0)
    epsilon = request.args.get('epsilon', type=float, default=0.02)

    if x is None or y is None:
        return jsonify({"error": "Missing x or y coordinates"}), 400
    if 'image' not in request.files:
        return jsonify({"error": "No image file provided"}), 400

    image_file = request.files['image']
    image = Image.open(image_file)
    image = np.array(image.convert("RGB"))

    predictor.set_image(image)
    input_point = np.array([[x, y]])
    input_label = np.array([1])
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]

    processed_image = process_image(image, masks, scores, input_point, input_label, normalize=bool(normalize), epsilon=epsilon)

    # Convert the processed image to base64
    img = Image.fromarray(processed_image)
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    return jsonify({"image_base64": img_str})

@app.route(f"{top_path}/ping")
def ping():
    return "pong", 200

if __name__ == '__main__':
    app.run(debug=True, port=5001)
```
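Once the server is up, you can sanity-check the route prefix via the /ping endpoint. A minimal sketch, assuming the defaults above (top_path = "/sam2", port 5001); if you change top_path, remember to update API_BASE_URL on the Gradio side to match:

```python
import requests

# Assumes the Flask server above is running locally with its default settings.
BASE = "http://127.0.0.1:5001/sam2"  # must mirror the server's top_path and port

resp = requests.get(f"{BASE}/ping")
print(resp.status_code, resp.text)  # expected: 200 pong
```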
In my case I placed this server-side daemon directly in the sam2 folder and ran it from there, so if you want a different layout, adjust the checkpoint and config paths accordingly.
How to install SAM2
```bash
python3 -m venv venv
source venv/bin/activate
git clone https://github.com/facebookresearch/segment-anything-2.git
cd segment-anything-2 && pip install -e .
cd checkpoints && \
./download_ckpts.sh && \
cd ..
```
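After installation, the two processes can be started like this (the filenames server.py and app.py are just placeholders for the Flask and Gradio code above; use whatever names you saved them under):

```bash
# Terminal 1: start the SAM2 server (loads the checkpoint, so give it a moment)
cd segment-anything-2
python server.py   # placeholder name for the Flask code above

# Terminal 2: start the Gradio UI (serves on 0.0.0.0:7860 by default)
python app.py      # placeholder name for the Gradio code above
```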
With Gradio running, here is what the web UI lets you do:
wherever you click on the image, Segment Anything 2 (SAM2) predicts a mask for that point, and you can check the result almost in real time.
I wanted to smooth out the mask edges a bit more, and as a tool for quickly checking results while tweaking the parameters I want to test, I think it turned out quite well.
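For reference, the epsilon slider maps directly to the tolerance passed to cv2.approxPolyDP in straighten_mask_edges, expressed as a fraction of each contour's perimeter. A standalone sketch of its effect on a synthetic blob:

```python
import cv2
import numpy as np

# A round blob as a stand-in for a SAM2 mask
mask = np.zeros((100, 100), dtype=np.uint8)
cv2.circle(mask, (50, 50), 30, 1, -1)

contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for eps_frac in (0.002, 0.02, 0.1):
    tolerance = eps_frac * cv2.arcLength(contours[0], True)
    approx = cv2.approxPolyDP(contours[0], tolerance, True)
    # Fewer vertices means straighter, more polygonal edges
    print(f"epsilon={eps_frac}: {len(approx)} vertices")
```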
I'll keep making small updates in this repository, so if you're interested, I'd appreciate a watch.
(I've already added multi-point prompts of up to two points and box selection.)
Geniee is always hiring, so if you're interested not only in cutting-edge generative AI but also in machine learning, data, BI, ad tech, or martech, please take a look here.
We look forward to creating innovative solutions together with colleagues who are passionate about the latest technology. Come join us.