# Trying SAM3 (zero-shot instance segmentation)
## What is this

Run SAM3, recently released by Meta, for inference in a local environment. The official examples are slightly awkward to use, so this post documents the setup and a standalone script.
Video inference is not covered here.
Inference uses roughly 8000 MiB of VRAM.
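If you want to check the footprint on your own GPU, one quick way (a minimal sketch using standard PyTorch APIs, not part of the setup below) is to print the peak allocation after the inference script has run:

```python
import torch

# Peak VRAM allocated through PyTorch's caching allocator for this process,
# reported in MiB; call this after inference has finished.
peak_mib = torch.cuda.max_memory_allocated() / 2**20
print(f"peak allocated VRAM: {peak_mib:.0f} MiB")
```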
## Environment setup

- Clone the following repository:
  https://github.com/facebookresearch/sam3
- Make a small change to pyproject.toml:
```diff
diff --git a/pyproject.toml b/pyproject.toml
index e4998de..aede923 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ name = "sam3"
 dynamic = ["version"]
 description = "SAM3 (Segment Anything Model 3) implementation"
 readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.12"
 license = {file = "LICENSE"}
 authors = [
     {name = "Meta AI Research"}
@@ -33,6 +33,16 @@ dependencies = [
     "iopath>=0.1.10",
     "typing_extensions",
     "huggingface_hub",
+    "einops>=0.8.1",
+    "decord>=0.6.0",
+    "pycocotools>=2.0.10",
+    "psutil>=7.1.3",
+    "opencv-python-headless>=4.11.0.86",
+    "matplotlib>=3.10.7",
+    "pandas>=2.3.3",
+    "scikit-learn>=1.7.2",
+    "scikit-image>=0.25.2",
 ]
 
 [project.optional-dependencies]
@@ -82,8 +92,8 @@ train = [
 "Homepage" = "https://github.com/facebookresearch/sam3"
 "Bug Tracker" = "https://github.com/facebookresearch/sam3/issues"
 
-[tool.setuptools]
-packages = ["sam3", "sam3.model"]
+[tool.setuptools.packages.find]
+include = ["sam3*"]
```
- Run `uv sync`:

```bash
uv sync
```
- Important: request access to the weights on the Hugging Face Hub and log in with the Hugging Face CLI. Concretely:
  - Log in to your Hugging Face account on the web.
  - Request access on the model repo page (simply trying to download the weights is enough to trigger the request).
  - Wait a while for approval (about 10 minutes).
  - Once approved, log in with `hf auth login`.
Quoted from the official README:

> Hugging Face repo. Once accepted, you need to be authenticated to download the checkpoints. You can do this by running the following steps (e.g. hf auth login after generating an access token.)
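As an alternative to the CLI, authentication can also be done from Python via `huggingface_hub` (already a dependency); a minimal sketch, assuming an access token has been generated in the web UI:

```python
from huggingface_hub import login

# Equivalent to `hf auth login`; the token string below is a placeholder,
# paste the access token generated on the Hugging Face website.
login(token="hf_xxxxxxxxxxxxxxxx")
```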
## Inference

Run inference with the script below; the weights are downloaded automatically on first run.
Edit the EDIT ME block for your prompts and image.
The visualization and the NMS were implemented with Codex. (The codebase does ship an NMS implementation, but it frequently OOMs, so a simple custom one is used here. A visualization helper also exists, but cv2 output was preferred.)
```python
import cv2
import numpy as np
import torch
from PIL import Image

from sam3.model.sam3_image_processor import Sam3Processor
from sam3.model_builder import build_sam3_image_model


def opencv_visualization(
    image: np.ndarray,
    masks: torch.Tensor,
    boxes: torch.Tensor,
    scores: torch.Tensor,
    score_threshold: float = 0.0,
    color: tuple = (0, 255, 0),
    alpha: float = 0.4,
) -> np.ndarray:
    """Create an OpenCV visualization with masks and bounding boxes.

    Args:
        image (np.ndarray): Image array shaped (H, W, 3); BGR in this script.
        masks (torch.Tensor): Boolean masks in shape (N, 1, H, W) or (N, H, W).
        boxes (torch.Tensor): Bounding boxes in xyxy format with shape (N, 4).
        scores (torch.Tensor): Confidence scores for each mask.
        score_threshold (float): Minimum score required to visualize an instance.
        color (tuple): Color used for the mask overlay, boxes, and labels.
        alpha (float): Opacity of the colored mask overlay.

    Returns:
        np.ndarray: BGR image with overlays suitable for cv2.imwrite.

    Raises:
        ValueError: If the number of masks, boxes, and scores does not match.
        ValueError: If mask shapes cannot be aligned with the image.
    """
    if masks.shape[0] != boxes.shape[0] or boxes.shape[0] != scores.shape[0]:
        raise ValueError("masks, boxes, and scores must have the same length.")

    height, width = image.shape[0], image.shape[1]
    overlay = image.copy()
    for idx in range(masks.shape[0]):
        score = float(scores[idx])
        if score < score_threshold:
            continue

        mask_np = masks[idx].detach().cpu().numpy()
        if mask_np.ndim > 2 and mask_np.shape[0] == 1:
            mask_np = np.squeeze(mask_np, axis=0)
        if mask_np.ndim != 2:
            raise ValueError("Each mask must be a 2D array.")
        if mask_np.shape != (height, width):
            mask_np = cv2.resize(
                mask_np.astype(np.float32),
                (width, height),
                interpolation=cv2.INTER_NEAREST,
            )

        # Alpha-blend the instance color into the masked region.
        mask_region = mask_np > 0.5
        overlay[mask_region] = (
            alpha * np.array(color) + (1 - alpha) * overlay[mask_region]
        ).astype(np.uint8)

        # Draw the bounding box and the confidence score above it.
        x0, y0, x1, y1 = boxes[idx].detach().cpu().numpy()
        x0_i, y0_i = max(int(x0), 0), max(int(y0), 0)
        x1_i, y1_i = min(int(x1), width - 1), min(int(y1), height - 1)
        cv2.rectangle(
            overlay,
            (x0_i, y0_i),
            (x1_i, y1_i),
            color=color,
            thickness=2,
        )
        label_text = f"{score:.2f}"
        cv2.putText(
            overlay,
            label_text,
            (x0_i, max(y0_i - 5, 0)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            color,
            1,
            lineType=cv2.LINE_AA,
        )
    return overlay


def apply_mask_nms(
    masks: torch.Tensor,
    boxes: torch.Tensor,
    scores: torch.Tensor,
    score_threshold: float,
    mask_iou_threshold: float,
    box_iou_threshold: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Apply mask-based NMS to filter detections without external dependencies.

    Args:
        masks (torch.Tensor): Masks shaped (N, H, W) or (N, 1, H, W).
        boxes (torch.Tensor): Bounding boxes in xyxy format with shape (N, 4).
        scores (torch.Tensor): Confidence scores for each mask.
        score_threshold (float): Minimum score to consider a detection for NMS.
        mask_iou_threshold (float): IoU threshold for suppressing overlapping masks.
        box_iou_threshold (float): IoU threshold for suppressing overlapping boxes.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Filtered masks, boxes, and scores.

    Raises:
        ValueError: If masks have unsupported dimensions.
    """
    if masks.dim() == 4 and masks.shape[1] == 1:
        masks_for_nms = masks[:, 0]
    elif masks.dim() == 3:
        masks_for_nms = masks
    else:
        raise ValueError("masks must have shape (N, H, W) or (N, 1, H, W).")

    # Work on CPU to avoid building large intermediate tensors on the GPU.
    masks_for_nms = masks_for_nms.to("cpu").bool()
    scores_cpu = scores.to("cpu")
    boxes_cpu = boxes.to("cpu")

    # Greedy NMS: visit detections in descending score order and keep one only
    # if it overlaps no already-kept detection beyond the thresholds.
    order = torch.argsort(scores_cpu, descending=True)
    keep_mask = torch.zeros(len(order), dtype=torch.bool, device="cpu")
    kept_masks: list[torch.Tensor] = []
    kept_boxes: list[torch.Tensor] = []
    for order_idx, det_idx in enumerate(order):
        score = scores_cpu[det_idx]
        if score < score_threshold:
            continue
        current_mask = masks_for_nms[det_idx]
        current_box = boxes_cpu[det_idx]
        if not kept_masks:
            keep_mask[order_idx] = True
            kept_masks.append(current_mask)
            kept_boxes.append(current_box)
            continue

        # Mask IoU against every kept detection.
        stacked_kept = torch.stack(kept_masks, dim=0)
        intersection = (stacked_kept & current_mask).sum(dim=(1, 2)).float()
        union = (stacked_kept | current_mask).sum(dim=(1, 2)).float()
        ious = torch.where(union > 0, intersection / union, torch.zeros_like(union))

        # Box IoU against every kept detection.
        kept_boxes_tensor = torch.stack(kept_boxes, dim=0)
        x1 = torch.maximum(current_box[0], kept_boxes_tensor[:, 0])
        y1 = torch.maximum(current_box[1], kept_boxes_tensor[:, 1])
        x2 = torch.minimum(current_box[2], kept_boxes_tensor[:, 2])
        y2 = torch.minimum(current_box[3], kept_boxes_tensor[:, 3])
        inter_w = torch.clamp(x2 - x1, min=0)
        inter_h = torch.clamp(y2 - y1, min=0)
        inter_area = inter_w * inter_h
        current_area = (current_box[2] - current_box[0]) * (
            current_box[3] - current_box[1]
        )
        kept_areas = (kept_boxes_tensor[:, 2] - kept_boxes_tensor[:, 0]) * (
            kept_boxes_tensor[:, 3] - kept_boxes_tensor[:, 1]
        )
        box_union = current_area + kept_areas - inter_area
        box_ious = torch.where(
            box_union > 0, inter_area / box_union, torch.zeros_like(box_union)
        )

        if torch.all(ious <= mask_iou_threshold) and torch.all(
            box_ious <= box_iou_threshold
        ):
            keep_mask[order_idx] = True
            kept_masks.append(current_mask)
            kept_boxes.append(current_box)

    selected_indices = order[keep_mask].to(masks.device)
    return masks[selected_indices], boxes[selected_indices], scores[selected_indices]


# === EDIT ME ===
prompts = ["dog", "cat"]
colors = [(255, 0, 0), (0, 255, 0)]
image_path = "sample.jpg"
# === EDIT ME ===

# Load the model
model = build_sam3_image_model()
processor = Sam3Processor(model, confidence_threshold=0.3)

# Load an image
image = Image.open(image_path).convert("RGB")
inference_state = processor.set_image(image)

overlay_image = image.copy()
overlay_image = np.array(overlay_image)
overlay_image = cv2.cvtColor(overlay_image, cv2.COLOR_RGB2BGR)

# Run one text prompt at a time and draw each prompt's detections in its own color.
for prompt, color in zip(prompts, colors):
    output = processor.set_text_prompt(state=inference_state, prompt=prompt)
    masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    masks, boxes, scores = apply_mask_nms(
        masks=masks,
        boxes=boxes,
        scores=scores,
        score_threshold=0.3,
        mask_iou_threshold=0.1,
        box_iou_threshold=0.1,
    )
    print("Image Masks shape:", masks.shape)
    print("Image Boxes shape:", boxes.shape)
    print("Image Scores shape:", scores.shape)
    overlay_image = opencv_visualization(
        image=overlay_image,
        masks=masks,
        boxes=boxes,
        scores=scores,
        score_threshold=0.3,
        color=color,
        alpha=0.5,
    )
cv2.imwrite("visualization.png", overlay_image)
```
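If the masks need to be stored or exchanged rather than only drawn, pycocotools (added to the dependencies above) can encode them as COCO-style RLE. A minimal sketch, not part of the original script, operating on the post-NMS masks of the last prompt in the loop:

```python
import numpy as np
from pycocotools import mask as mask_utils

# pycocotools expects a Fortran-ordered uint8 array of shape (H, W) per mask.
rles = [
    mask_utils.encode(np.asfortranarray(m.squeeze().cpu().numpy().astype(np.uint8)))
    for m in masks
]
print(f"encoded {len(rles)} masks as RLE")
```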
## Inference results

