📷

はじめてのセマンティックセグメンテーション(Semantic Segmentation)

2022/10/09に公開約6,800字

目的

DeepLearningの物体検知経験を積みたいと思ったので、SemanticSegmentationに取り組んでみた内容を残します。特に自身が不明だった「何を学習しているのか」「セマンティックセグメンテーションの評価指標とは」を中心に記載します。
本記事で述べていないデータ拡張等の関数はjupyter notebookをご参照ください。
https://github.com/orange7mam/PracticeSemanticSegmentation.git

読者対象

  • pytorch初学者
  • 物体検知初学者

動作環境

GoogleColaboratory
GPU TeslaT4

セグメンテーションタスクとは

バウンディングボックス(bbox)による物体検出では対象物体を矩形領域で囲い、矩形領域の座標や物体種類を回帰(や分類)を用いて推定します。
⼀⽅セグメンテーションタスクではピクセル1つ1つにおいて分類を行うことで、物体の種類を推定します。全てのピクセルで出力されるので、⼊⼒画像と同じ画像サイズの出⼒が得られます。画像内で斜めに映っている物体を検知する等、矩形領域では説明(アノテーション作業)が難しい物体を検知する際に利⽤したいですね。⾃動運転デモでよくある⾛⾏レーン検知の映像をイメージしてもらうのが分かりやすいと思います。

データセット

今回はPascalVOC2012データセットを使⽤します。
検出対象は20種類で、「背景」と「未分類」合わせると22クラスの分類問題を解くことになります。
学習⽤データ(train)と評価⽤データ(val)が最初から分けられているので、そのまま利⽤します。
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit

!wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
!tar -xvf VOCtrainval_11-May-2012.tar

ディレクトリ構成

VOC2012のデータには画像分類,bboxによる物体検知用のデータセットも含まれていますので、今回以下ディレクトリのデータを使用します。

  • /ImageSets/Segmentation
    • train,valそれぞれの画像ファイル名を記載したtxtファイルが格納。最初から分けられているので、ここにあるファイル群をそのまま使用します。
  • /JpegImages
    • 入力画像となるデータがまとまっています。
  • /SegmentationClass
    • 今回使うマスク画像(教師画像)がここに格納されています。

セグメンテーションの教師データについて

通常の画像分類であれば⼊⼒の画像⼀枚に対して1つの正解ラベルが付与されます。(例えばMNISTの手書き数字識別等)
セグメンテーションではこれを各ピクセルレベルで⾏います。
そのため正解画像を読み込むと、そのshapeは(Width,Height)になります。
通常の画像はRGB情報を持つため、各ピクセルで3次元情報がありますが、マスク画像ではクラス番号の1次元しか持ちません。
ではどのように、クラスごとの⾊分けを⾏うのでしょうか?
そこで画像のカラーパレットを利⽤します。
カラーパレットは任意のインデックスと指定のRGBを1対1対応させたものです。
つまり、インデックスさえ分かれば(予測できれば)RGBに変換が可能になります。

なおカラーパレットの取得は以下コードで可能です。

color_pallete = Image.open("画像のパス").getpalette()

VOC2012のカラーパレットのサンプルは以下の通りです。
⼈のindexは15番なので、RGBに変換するとピンクで⽰されます。

Datasetクラス

Datasetクラスでは画像⾃体は保存しません。画像のパスを保存しているので、画像の呼び出し方(getitemメソッド)を記載しています。このReadのタイミングでデータ拡張等の画像前処理を加えることで、拡張後画像を毎回dataset変数に保存せずにデータの表現⼒を上げられます(online Augumentation)。 今回__getitem__メソッドでは3つの加⼯を加えています。

  • ワンホット加⼯: 今回は各ピクセルごとに分類問題を⾏うため、正解ラベルをワンホットベクトルに変換します。 例えば、みかん:0、りんご:1、オレンジ:2のような3値分類を⾏う場合、ワンホットベクトルに変換したリ ンゴは、[0,1,0]のように表されます。 DLの出⼒値を確率とみなすことで、[0.1 , 0.8 , 0.1]であればおそらくリンゴであると推定が可能になります。
  • データ拡張(augumentations):回転や拡⼤縮⼩等の加⼯
  • 正規化加⼯(preprocessing):
    • 今回のモデルではImagenetによる事前学習済みモデルを使⽤します。事前学習で使ったデータセットと同じ平均、標準偏差を⽤いて入力画像の正規化処理をします。
# データセット
class Dataset(BaseDataset):
    
    def __init__(
            self,
            ids,
            images_dir,
            masks_dir,
            augmentation=None,
            preprocessing=None,
            
    ):
        self.ids = ids
        self.images_fps = [os.path.join(images_dir, image_id)+".jpg" for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id)+".png" for image_id in self.ids]
        self.CLASSES = CLASSES #グローバルに宣言したクラス項目のリスト

        
        # クラスに対応した数字配列を用意
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in self.CLASSES]
        
        

        self.augmentation = augmentation 
        self.preprocessing = preprocessing
        self.Onehot = True


    


    def __getitem__(self, i):
        # read data
        #image = cv2.imread(self.images_fps[i])
        #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.asarray(Image.open(self.images_fps[i]))


        mask = Image.open(self.masks_fps[i])
        mask = np.asarray(mask)
        mask = np.where(mask == 255, len(self.CLASSES), mask)  # unlabeledのパレットインデックスを255番から最後番(今回は22番)に変更
        
       
        #onehotベクトルに変換
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')
        

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
            
        
            
        return image, mask

    def __len__(self):
        return len(self.ids)

モデル

今回はDeeplabV3+を使います。
segmentation_models.pytorchライブラリを使用します。
事前学習済み重みを利用するエンコーダー部分はご自身の実行環境に合わせて選んでください。
今回は学習時間短縮化を狙い、比較的軽量なefficientnet-b3を利用します。
Available Encoders

# モデルを宣言
ENCODER = 'efficientnet-b3'
#公式github参照 https://github.com/qubvel/segmentation_models.pytorch
ENCODER_WEIGHTS = 'imagenet'


ACTIVATION = "softmax2d"  #softmax2dは(N,C,H,W)のCの次元に対してsoftmax計算してくれる。
DEVICE = 'cuda' if torch.cuda.is_available()  else "cpu"

DECODER = "DeepLabV3Plus"
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)
model = model.to(DEVICE)

損失関数

ワンホットベクトルを正解ラベルに使⽤しているので、クロスエントロピー誤差の利⽤が最初に考えられます。⼀⽅セマンティックセグメンテーションの場合、たいていの正解ラベル情報は背景(黒)です。そのため全て黒と推定することも⼀つの局所最適解になり得ます。セグメンテーションタスクでは不均衡データを考慮した誤差関数が望ましいです。
今回使⽤しているsegmentation_models_pytorchのloss functionには引数としてmodeがあります。”multi class”では⼀枚の画像には1種類の物体が映っている、”multilabel”では1枚の画像に複数種類の物体が映っているものとして計算⼿法を変えているので注意してください。

今回は不均衡データに対する損失関数を2点紹介します。

  • Focal loss
    (1 − p_t)^γを係数として⽤意することで、予測確率 が⼤きくなる(背景等予測しやすいピクセル)時の全体のlossに対する寄与率を⼩さくする狙いがあります。
    FocalLoss(p_t) = -(1-p_t)^\gamma log(p_t)
  • DiceLoss
    こちらはクロスエントロピーから離れた指標です。
    第⼆項分⼦の値が⼤きい時(予測結果と正解の⼀致領域が⼤きい)、全体の計算結果が⼩さくなる⽤に設計されています。
    DiceLoss = 1 - \frac{2*(Predict\land Actual)}{Predict \lor Actual}

評価指標

lossとは異なり、評価指標では大きくなって欲しい値を設定します。bboxを用いた物体検知でも使用されるIoU(Intersection over Union)を使用します。正解ラベルと予測が⼀致する⾯積割合を計算した指標です。

\begin{aligned} IoU &=\frac{Actual \land Predict}{Actual + Predict - Actual \land predict} \\ &= \frac{tp}{tp+fp+fn}\\ \end{aligned}

学習結果

今回は30Epochのみ実行します。マシンスペックや時間に余裕がある方はより大きなEpoch数で学習を実行してください。

成功例


失敗例

VOC2012に収録されている20種類の物体アノテーションデータは不均衡データです。特に人が映った画像データが400枚超と多く、人以外の物体が100枚前後であることを踏まえるとアンバランスなデータになっています。人以外の物体に対しての検知精度を向上させる必要があります。
PASCAL VOC2012 DevelopmentKit

最後に

いかがだったでしょうか?
SemanticSegmentationが何をゴールにしているのかだけではなく、メモリを大量に使用するofflineのデータ拡張方式とは異なるonlineのデータ拡張を可能にするDatasetクラスの書き方等、DeepLearningのコーディングに役立つ汎用的な知識も習得できた点が学びでした。
最後までご覧いただきありがとうございました!

参考図書

Kaggle Grandmasterに学ぶ 機械学習実践アプローチ

Discussion

ログインするとコメントできます