🙆

Pytorchによるセグメンテーション入門

2024/03/10に公開

サンプルコード全体

今回はsegmentation_models_pytorchを使用します。このライブラリにはUnetやDeeplabV3などのセグメンテーションモデルを簡単に作成することができるcreate_modelという関数があり、モデルの中身をよく知らなくてもセグメンテーションタスク用の深層学習モデルを作成できます。素晴らしいですね。
今回はモデルの入出力のサイズ関係などを確認するためのコードを書きました。ご自身の環境で実行して動作を確認してみてください。このコードはcudaなしでコピペでも動きます。
以下のコードは学習をするコードではありませんので注意してください。

import segmentation_models_pytorch as smp
import torch

"""
segmentation modelを作成
in_channelsは入力画像チャネルRGBの3
classesは予測するマスクの数(今回は4)
入力サイズ: [batch_size, in_channels, height, width]
出力サイズ: [batch_size, classes, height, width]
"""
model = smp.create_model(
    arch="DeepLabV3",
    encoder_name="resnet34",
    in_channels=3,
    classes=4,
)
# モデルの概要をprint
print(f"モデルの概要: {model}")

# [batch_size, channel(RGB), height, width]のinputダミーデータを作成
input_data = torch.randn(16, 3, 256, 256)
print(f"inputダミーデータ: {input_data.shape}")

# ダミーデータをモデルに入力して生の予測マスク(logits_mask)を出力
predicted_logits_mask = model(input_data)
print(f"予測マスクの形: {predicted_logits_mask.shape}")

# multilabelのデータセットの場合
print("multilabelの場合")
"""
Dice Lossの設定
mode: 正解マスクがクラスごとに独自のチャネルを持ち、クラスに属さないピクセルが0、属すピクセルが1の場合はmultilabel
(正解マスクの全てのクラスが1枚の画像の画素値として表現される場合はmulticlass)
from_logits: モデルの予測が活性化される前の生データであればTrue,
    multilabelの場合かつ各マスクについて0~1の範囲の確率マスクになっていればFalse,
    multiclassの場合かつマスクの各ピクセルがクラス値になっていればFalse
"""
dice_loss_multilabel = smp.losses.DiceLoss(
    mode=smp.losses.MULTILABEL_MODE, from_logits=True
)


# [batch_size, classes, height, width]の正解マスクのサンプルを作成(multilabel用)
ground_truth_mask_multilabel = torch.randint(0, 2, (16, 4, 256, 256))
print(f"multilabelの正解マスクの形: {ground_truth_mask_multilabel.shape}")

# 0batch目の0class目のマスク[height, width]の中の要素数を確認
output, counts = torch.unique(ground_truth_mask_multilabel[0][0], return_counts=True)
print(
    f"0batch目の0class目のマスクの中の要素数: {dict(zip(output.tolist(), counts.tolist()))}, "
    f"合計要素数: {counts.sum()}"
)

# lossの計算, 引数は予測、正解の順
loss_multilabel = dice_loss_multilabel(
    predicted_logits_mask, ground_truth_mask_multilabel
)
print(f"multilabel Dice Loss: {loss_multilabel}")

print("")

# multiclassのデータセットの場合
print("multiclassの場合")
# Dice Lossの設定
dice_loss_multiclass = smp.losses.DiceLoss(
    mode=smp.losses.MULTICLASS_MODE, from_logits=True
)

# [batch_size, 1, height, width]の正解マスクのサンプルを作成(multiclass用)
ground_truth_mask_multiclass = torch.randint(0, 4, (16, 1, 256, 256))
print(f"multiclassの正解マスクの形: {ground_truth_mask_multiclass.shape}")

# 0batch目のマスク[height, width]の中の要素数を確認
output, counts = torch.unique(ground_truth_mask_multiclass[0][0], return_counts=True)
print(
    f"0batch目のマスクの中の要素数: {dict(zip(output.tolist(), counts.tolist()))}, "
    f"合計要素数: {counts.sum()}"
)

loss_multiclass = dice_loss_multiclass(
    predicted_logits_mask, ground_truth_mask_multiclass
)
print(f"multiclass Dice Loss: {loss_multiclass}")

セグメンテーションとは

セグメンテーションは正解マスクの用意の仕方によって様々なものがあります。

  • Binary segmentation
    1つのクラスのマスクを予測する一番シンプルなセグメンテーションです。
    1つの画像に複数のクラスがあるときには対応できません。
    つまり、ピクセルが 1 とラベル付けされる唯一のクラスがあり、残りのピクセルは背景で、 0 とラベル付けされています。
    正解マスクの形状は(Batch_size, Height, Width)となります。
  • Multilabel segmentation
    Binary segmentationの単純な拡張で、クラス数を任意のClasses = 1..N に拡張しただけです。
    互いのクラスは相互に排他的でなく、Classesの次元で独自のチャネルを持っています。
    正解マスクの形状は(Batch_size, Classes, Height, Width)となります。
    後に示すMulticlass segmentationではクラスの重なりが許されませんが、Multilabelでは表現可能です。
  • Multiclass segmentation
    Classes = 1..N クラスがありますが、Multilabel segmentationと異なり、クラスは相互に排他的で、すべてのピクセルがこれらの値でラベル付けされた1枚のマスクを予測します。
    ターゲットマスクの形状は(Batch_size, Height, Width)となります。
    一枚のマスクに複数のカラーで塗り潰しが行われたような画像を出力します。これがセマンティックセグメンテーションと呼ばれています。一般的にBackgroundを0とした輝度値情報としてクラスを有します。
    インスタンスセグメンテーションやパノプティックセグメンテーションはセマンティックセグメンテーションの派生系です。

この記事ではMultilabel segmentationとMulticlass segmentationを取り扱います。基礎的なセグメンテーションであり、これを通らずしてインスタンスセグメンテーションはできません。基礎的だからこそデータセットの作成自体が容易であり、シンプルであることが利点です。研究などにセグメンテーションを取り入れる場合は十分に実用的です。

サンプルコードの詳細

モデルの作成

"""
segmentation modelを作成
in_channelsは入力画像チャネルRGBの3
classesは予測するマスクの数(今回は2)
入力サイズ: [batch_size, in_channels, height, width]
出力サイズ: [batch_size, classes, height, width]
"""
model = smp.create_model(
    arch="DeepLabV3",
    encoder_name="resnet34",
    in_channels=3,
    classes=4,
)
# モデルの概要をprint
print(f"モデルの概要: {model}")

この部分ではアーキテクチャにDeepLabV3、バックボーン(エンコーダ)にresnet34を使用したモデルを作成します。文字列で指定できるので簡単ですね。

アーキテクチャ一覧

[Unet, UnetPlusPlus, MAnet, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN]

エンコーダ一覧

['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_32x8d', 'resnext101_32x16d', 'resnext101_32x32d', 'resnext101_32x48d', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn107', 'dpn131', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', 'se_resnext101_32x4d', 'densenet121', 'densenet169', 'densenet201', 'densenet161', 'inceptionresnetv2', 'inceptionv4', 'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3', 'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7', 'mobilenet_v2', 'xception', 'timm-efficientnet-b0', 'timm-efficientnet-b1', 'timm-efficientnet-b2', 'timm-efficientnet-b3', 'timm-efficientnet-b4', 'timm-efficientnet-b5', 'timm-efficientnet-b6', 'timm-efficientnet-b7', 'timm-efficientnet-b8', 'timm-efficientnet-l2', 'timm-tf_efficientnet_lite0', 'timm-tf_efficientnet_lite1', 'timm-tf_efficientnet_lite2', 'timm-tf_efficientnet_lite3', 'timm-tf_efficientnet_lite4', 'timm-resnest14d', 'timm-resnest26d', 'timm-resnest50d', 'timm-resnest101e', 'timm-resnest200e', 'timm-resnest269e', 'timm-resnest50d_4s2x40d', 'timm-resnest50d_1s4x24d', 'timm-res2net50_26w_4s', 'timm-res2net101_26w_4s', 'timm-res2net50_26w_6s', 'timm-res2net50_26w_8s', 'timm-res2net50_48w_2s', 'timm-res2net50_14w_8s', 'timm-res2next50', 'timm-regnetx_002', 'timm-regnetx_004', 'timm-regnetx_006', 'timm-regnetx_008', 'timm-regnetx_016', 'timm-regnetx_032', 'timm-regnetx_040', 'timm-regnetx_064', 'timm-regnetx_080', 'timm-regnetx_120', 'timm-regnetx_160', 'timm-regnetx_320', 'timm-regnety_002', 'timm-regnety_004', 'timm-regnety_006', 'timm-regnety_008', 'timm-regnety_016', 'timm-regnety_032', 'timm-regnety_040', 'timm-regnety_064', 'timm-regnety_080', 'timm-regnety_120', 'timm-regnety_160', 'timm-regnety_320', 'timm-skresnet18', 'timm-skresnet34', 'timm-skresnext50_32x4d', 'timm-mobilenetv3_large_075', 'timm-mobilenetv3_large_100', 'timm-mobilenetv3_large_minimal_100', 'timm-mobilenetv3_small_075', 'timm-mobilenetv3_small_100', 'timm-mobilenetv3_small_minimal_100', 'timm-gernet_s', 'timm-gernet_m', 'timm-gernet_l', 'mit_b0', 'mit_b1', 'mit_b2', 'mit_b3', 'mit_b4', 'mit_b5', 'mobileone_s0', 'mobileone_s1', 'mobileone_s2', 'mobileone_s3', 'mobileone_s4']

in_channelは通常のRGB画像をモデルに入力する場合3になります。Grayスケール画像の場合は1です。ことのき、画像のwidth, heightは関係ありありません。

classesはセグメンテーションするクラスの数です。

ダミーデータについて

# [batch_size, channel(RGB), height, width]のinputダミーデータを作成
input_data = torch.randn(16, 3, 256, 256)
print(f"inputダミーデータ: {input_data.shape}")

これはPythonのデータローダを作成する際にBatch sizeを16にしたデータを模擬したものです。実際のデータローダはデータセットをBatchサイズごとに分割してモデルに渡しやすくするためのものです。学習の際には入力となる画像と正解となるマスクのペアを返すデータセットを作成する必要があります。

lossの計算について

# multilabelのデータセットの場合
print("multilabelの場合")
"""
Dice Lossの設定
mode: 正解マスクがクラスごとに独自のチャネルを持ち、クラスに属さないピクセルが0、属すピクセルが1の場合はmultilabel
(正解マスクの全てのクラスが1枚の画像の画素値として表現される場合はmulticlass)
from_logits: モデルの予測が活性化される前の生データであればTrue,
    multilabelの場合かつ各マスクについて0~1の範囲の確率マスクになっていればFalse,
    multiclassの場合かつマスクの各ピクセルがクラス値になっていればFalse
"""
dice_loss_multilabel = smp.losses.DiceLoss(
    mode=smp.losses.MULTILABEL_MODE, from_logits=True
)

...

# lossの計算, 引数は予測、正解の順
loss_multilabel = dice_loss_multilabel(
    predicted_logits_mask, ground_truth_mask_multilabel
)
print(f"multilabel Dice Loss: {loss_multilabel}")

lossの計算にはDice lossを用いました。Dice lossは一般的に使用されるlossで正解マスクと予測マスクの一致度合いを0~1の範囲表現します。Dice loss = 1 - Dice score という関係です。
smp.losses.DiceLossは通称soft dice scoreという微分可能に拡張されたscoreを使用しています。backwordするには微分可能である必要があるからです。soft diceは予測マスクが2値マスクとなっていない生のlogits_maskを入力するだけで計算できます。multiclassモードの場合は正解マスクを1枚で用意しているはずですが、class毎にdiceを計算するためにonehotエンコーディングしてから計算するようになっています。

Discussion