🦀

Segment Anything Model 2 (SAM 2)の動画データに対するセグメンテーションのチュートリアル

2024/08/09に公開

こんにちは、HACARUS でインターンをしている末原です。

今回は、画像・動画に対してゼロショットでセグメンテーションを行うことができる最新のSegment Anything Model 2(SAM 2)を紹介します!

Segment Anything Model 2 とは

SAM (Segment Anything Model) はゼロショット能力と柔軟なプロンプティングで高い評価を受けていますが、主に画像データに対応していました。ところが、2024年7月29日に Meta から公開された SAM 2 (Segment Anything Model 2) は、画像だけでなく動画データにも対応できるようになりました。この新モデルでは、単一画像の処理時間が約1/6に短縮され、セグメンテーションの精度も向上しています。さらに、アーキテクチャの変更を必要とせず、単一画像を動画の1フレームとして扱うことで、動画に対する推論も可能となっています。

SAM 2 のアーキテクチャ
SAM 2 のアーキテクチャ (SAM 2 の GitHub より引用)

Memory Bank

メモリバンクは、動画内の対象オブジェクトに関する過去の予測情報を保持します。これは最大 N 個の最新フレームのメモリを FIFO キューで管理し、プロンプトからの情報を最大 M 個のプロンプトされたフレームの FIFO キューで保存することで実現されています。実際の内部処理としては、SAM2Base クラスの num_maskmem パラメータで、メモリバンクから予測に使用するフレーム数を決定します。デフォルトではこの値は 7 に設定されています。フレームの選択は、現在のフレームの直前または直後から始め、その後 SAM2Base クラスの memory_temporal_stride_for_eval で指定した r フレームごとに選択します。 詳しいフレーム選択の処理はコードをご覧ください。

SAM 2 の動画データに対するチュートリアル

公式の GitHub リポジトリ でも動画データに対するセグメンテーションのチュートリアルは準備されていますが、動画ファイルを一旦静止画に切り出す必要があったり、いろいろと躓いた点があったので、Google Colab で動くように整理しました。今回使用したコードは Google Colab でも公開しています。非常に簡単に動かせるので興味のある方はぜひ試してみてください!

環境構築

まず SAM 2 のコードを git clone します。

!git clone https://github.com/facebookresearch/segment-anything-2.git
%cd ./segment-anything-2

次に、setup.pyを利用して SAM 2 をインストールします。

pip install -e .
上記の実行に時間がかかる場合

SAM 2 の学習済みモデルファイルをダウンロード

cd checkpoints
./download_ckpts.sh
cd ../

必要なモジュールの読み込み

import time

import os
import cv2
import torch
import numpy as np
import pandas as pd
from glob import glob
from natsort import natsorted
import matplotlib.pyplot as plt

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.build_sam import build_sam2_video_predictor

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# マスク描画用関数
def show_mask(mask, ax, obj_id=None):
    cmap = plt.get_cmap("tab10")
    cmap_idx = 0 if obj_id is None else obj_id
    color = np.array([*cmap(cmap_idx)[:3], 0.6])

    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# プロンプト描画用関数
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

動画の準備

今回は、京都の嵐山付近で見かけたサワガニの動画データを使ってセグメンテーションをしてみます。動画ファイルは、以下のようにダウンロードできます。

input_video_path = "/content/kani.mp4" # 入力の動画ファイルパス
!wget 'https://drive.google.com/uc?export=download&id=1-K1D6qDkFIv3RqsLqIj0nC20NCcgaO3o' -O {input_video_path}

ただし、SAM 2 の標準的な機能では、動画ファイルのまま推論を実施することはできず、一度、ファイル名が連番になるように静止画に変換する必要があります。SAM 2 内部の load_video_frames関数は、[".jpg", ".jpeg", ".JPG", ".JPEG"] のみに対応しているため以下のように ffmpeg を使って動画から .jpg への切り出しを行います。

input_img_dir = "./input_kani/" # 入力の静止画用ディレクトリ
!ffmpeg -i {input_video_path} -r 3 -q:v 2 -start_number 0 {input_img_dir}/'%05d.jpg'


使用する動画の最初のフレーム

モデルの準備

ここでは、SAM 2 の Large モデルを使用します。build_sam2_video_predictor でモデルの読み込みを実施し、各フレームの画像に対して predictor.init_state を実行します。

# モデルのロード (Large を使用)
sam2_checkpoint = "checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)

# 初期化 + 動画の各フレームの image embedding 求める
inference_state = predictor.init_state(video_path=input_img_dir)

先頭のフレームに対して、点のプロンプトを設定し、セグメンテーションを求めてみると、うまくサワガニの領域だけマスクできていることがわかります。

## サワガニの座標をプロンプトとして指定
ann_frame_idx = 0
ann_obj_id = 1
input_point = np.array([[800, 380], [750, 360]], dtype=np.float32)
input_label = np.array([1, 1], np.int32)

## 指定したフレームに対する入力プロンプトのセグメンテーションを計算
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=input_point,
    labels=input_label,
)

## セグメンテーション結果を描画
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
image = cv2.imread(frame_names[ann_frame_idx])
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])


先頭フレームのセグメンテーション結果

動画の各フレームに対するセグメンテーションの実行

各フレームに対する推論は以下のように実行できます。セグメンテーションの結果は video_segments に格納されていきます。

video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

次に、得られた推論結果を静止画として保存します。

# 結果の静止画を保存
plt.close("all")
for out_frame_idx in range(len(frame_names)):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    plt.axis('off')
    plt.tight_layout(pad=0)

    # 元画像の描画
    image = cv2.imread(frame_names[out_frame_idx])
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    plt.imshow(image)

    # マスクの描画
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)

    # 結果を保存
    basename = os.path.basename(frame_names[out_frame_idx])
    output_frame = os.path.join(output_img_dir, basename)
    plt.savefig(output_frame)
    plt.close()

最後に、静止画から動画に ffmpeg で変換します。

# 結果の画像から動画に変換
!ffmpeg -framerate 12 -i {output_img_dir}/%05d.jpg -c:v libx264 -r 30 -pix_fmt yuv420p {output_video_path}


セグメンテーション結果 (3MB 以下の .gif にしないといけなかったのでカクカク)

できあがった動画をみると、入力のプロンプトとして、1フレームにしか情報を与えていないにも関わらず、かなり高い精度でセグメンテーションができていることがわかりました。途中でサワガニが岩陰に隠れてしまいますが、しばらくすると再度発見できているのは驚きました。

カメラの視点が急激に変化した場合や、対象オブジェクトの動きが激しい場合、対象オブジェクトが画面外に長時間出る場合は、最初のフレームだけでなく、数フレームのプロンプトを与えることで、正確な推論が期待できるそうです。

まとめ

SAM 2 を使って、簡単に動画から高品質のセグメンテーション結果を得ることができました。ここまで性能が高いと、いろんな場面にすぐに応用できそうなので、今後、特に注目の技術のように思います。

参考文献

HACARUS Tech Blog

Discussion