🍣

Segment Anything のチュートリアル

2024/05/21に公開

Segment Anything のチュートリアル

こんにちは、HACARUS でインターンをしている朱です。

本記事では、 Meta によって公開された高性能のセグメンテーションモデルである Segment-Anything Model の初歩的な使い方について解説します。

Segment Anythingとは

Segment-Anything Model(SAM)は、画像内の関心のあるオブジェクトを切り出すための高性能のセグメンテーションモデルです。このモデルでは、関心のあるオブジェクトを、点やバウンディングボックスなどによるプロンプトとして指定することで、汎用性が高く高品質なセグメンテーション結果を得ることができます。SAM によるセグメンテーションの流れは以下の図のようになります。

Untitled

SAM のアーキテクチャ

SAM のアーキテクチャは、画像を image encoder でテンソルに変換し、点やバウンディングボックスのプロンプトは prompt encoder でテンソルに変換します。そして、これらを mask decoder に入力することで、目的のセグメンテーションマスクを生成するという仕組みになっています。

より詳細を知りたい方は、公式の GitHub もしくは、論文 を確認してみてください。

本記事では、 SAM の初歩的な使い方について解説します。

  1. 環境設定: Python環境を構成し、PyTorch や opencv-python などの必要なライブラリと、SAMの特定の依存関係をインストールします。
  2. 点プロンプトによるセグメンテーション: 画像に対して点を指定して SAM によるセグメンテーションを実施する方法を解説します。
  3. バウンディングボックスプロンプトによるセグメンテーション: 画像に対してバウンディングボックスを指定して SAM によるセグメンテーションを実施する方法を解説します。
  4. バウンディングボックスの標準化: SAM の応用例として、アノテーション品質を向上させるアイデアを紹介します。ここでは、アノテーターごとのバウンディングボックスサイズのばらつきを自動で統一させる方法を紹介します。

SAM の準備

  1. Python 3.8 以上の環境を準備
  2. ライブラリのインストール
    1. pip install opencv-python matplotlib
    2. pip install git+https://github.com/facebookresearch/segment-anything.git
  3. notebookでの推論
    1. 学習済みモデルのチェックポイントをダウンロードします。SAM checkpoints から適切なものを選択してください。
    2. ダウンロードしたチェックポイントを使ってモデルを読み込みます。この作業には通常、モデルの構造を定義した情報と、チェックポイントファイルへのパスが必要になります。
    3. 画像をモデルが理解できる形式に変換するため、特定の前処理を含むエンコードを行います。
    4. 画像上の関心領域(ROI)を指定するため、点やボックスなどのアノテーション情報を入力します。
    5. モデルが生成したセグメンテーションマスクを画像に重ねて表示し、特定の領域をどのように認識しているかを視覚的に確かめます。

描画用関数の設定

描画のための便利な関数を以下のように定義しておきます。

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax, edgecolor='green'):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=edgecolor, facecolor=(0,0,0,0), lw=2))  

画像の準備

サンプル画像を準備します。

image = cv2.imread('images/sushi.jpg')
# resize the image to the expected size if you need
image = cv2.resize(image, (800, 600))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(8,8))
plt.imshow(image)
plt.axis('on')
plt.show()

推論の準備

上記の画像を推論するためのモデルの読み込み等を行います。

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "../sam_vit_h_4b8939.pth"
model_type = "vit_h"

# device = "cuda"
device = "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

predictor.set_image(image) 

プロンプト1: 1つの点を指定するプロンプト

画像内で特定のオブジェクト(例: 寿司)をセグメンテーションするために、画像内の点の座標(x, y) をプロンプトとして指定できます。それぞれの点には、1 または 0 のラベルが付けられ、これによって点が前景(関心のあるオブジェクトの一部、この場合はサーモン)か、背景(オブジェクトの一部でない)かを明示します。

# サーモンの座標をプロンプトとして指定する
input_point = np.array([[530, 150]])
input_label = np.array([1])

plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()

実際に、上記のプロンプトで推論を実施してみましょう。multimask_output のパラメータが True(デフォルト)の場合、SAM は3つのマスクを出力として生成します。これらのマスクは、プロンプトを基にした画像の異なる解釈を示しており、プロンプトが曖昧な場合はさまざまなマスクが得られます。また、各マスクに対するモデルの信頼度を表すスコアも得られます。

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
) 
# masks shape: (3, 600, 800): (number_of_masks) x H x W

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

プロンプト2: 複数の点を指定するプロンプト

SAM では、プロンプトとして複数の点を同時に与えることも可能です。ここでは、それぞれの寿司から 1 点ずつ指定した場合の推論結果を確認してみましょう。

points = np.array([
	[530, 150], 
	[370, 180], 
	[250, 180], 
	[300, 300], 
	[520, 290], 
	[600, 250]
])
labels = np.array([1, 1, 1, 1, 1, 1])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
masks, _, _ = predictor.predict(
    point_coords=points,
    point_labels=labels,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
plt.figure(figsize=(8,8))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(points, labels, plt.gca())
plt.show()

プロンプト3: ラベルで背景を指定するプロンプト

点を指定するプロンプトでは、点に対応するラベルとして、前景か背景かを追加情報として渡す必要があります(前景なら 1、背景なら 0 )。ここでは、サーモン部分を前景として指定し、米の部分を背景として、サーモン部分のみ抽出できるかどうか試してみます。

# 単一の点のみで、背景指定をしない場合
input_point = np.array([[530, 150]])
input_label = np.array([1])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.show() 

# 米部分を背景として指定する場合
input_point = np.array([[530, 150], [500, 200]])
input_label = np.array([1, 0])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.show() 

プロンプト4: バウンディングボックスを指定するプロンプト

SAM は以下のように、プロンプトとしてバウンディングボックスの座標を x0, y0, x1, y1 の形式で指定することもできます。

input_box = np.array([370, 70, 630, 260]) # [x0, y0, x1, y1]

masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.show()

プロンプト 5: バウンディングボックスと点を同時に指定するプロンプト

以下のように、バウンディングボックスと点の両方を組み合わせてプロンプトを与えることができます。ここでは、バウンディングボックスと背景の点を指定してみます。

input_box = np.array([390, 80, 600, 250]) # [x0, y0, x1, y1]
input_point = np.array([[500, 200]])
input_label = np.array([0])

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
# plt.axis('off')
plt.show()

SAM の応用例: バウンディングボックスサイズの統一化

アノテーションデータを作成する際、アノテーターごとの品質が問題になることがあります。例えば、物体検知タスクの場合、同じ物体でもバウンディングボックスの大きさはアノテータごとに異なります。ここでは、SAM を活用して、より正確なバウンディングボックスに修正することでアノテーション品質を高める方法を紹介します。具体的には、プロンプト 4 で指定したバウンディングボックスをアノテーターによるものと想定し、SAM の結果得られたマスクから新たなバウンディングボックスを計算することで、よりタイトなバウンディングボックスを得ることができます。

# find the box of the mask
mask = masks[0]
y, x = np.where(mask)
y0, y1 = y.min(), y.max()
x0, x1 = x.min(), x.max()
shrink_box = np.array([x0, y0, x1, y1])

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_box(shrink_box, plt.gca(), edgecolor='red')
# plt.axis('off')
plt.show()

※ アノテーターによって与えられた緑のバウンディングボックスと SAM によって修正された赤のバウンディングボックス

まとめ

Segment Anything を使って点やバウンディングボックスによるプロンプトを用いて、簡単に高品質のセグメンテーション結果を得ることができました。SAM は、医療画像への拡張や、CVAT の拡張機能としての組み込みなど、さまざまな応用事例も公開されており、今後も発展が期待できそうです。

References

HACARUS Tech Blog

Discussion