🔬

PyTorchとU-Netで作る領域分割モデル

2023/03/16に公開

はじめに

こんにちは。機械学習を完全に理解した中沢(@shnakazawa_ja)です[1][2]

世の中にはテーブルデータを対象とした機械学習モデルのチュートリアルは多くありますが、画像に対するものは少なく、コードまで提供されているものは更に少なくなります。そこで、本シリーズでは基本的なコンピュータビジョンモデリングの手法をPythonで実装していきます。



今回はSemantic segmentation (領域分割/セグメンテーション) を扱います。
モデルのアーキテクチャにはU-Netを採用します。10年近く前の手法ですが、シンプルなアーキテクチャながら安定した精度が出るため、セグメンテーションの第一手として今でもよく使われる手法です。本稿ではPyTorchを用いてU-Netを一から組み立てていきます[3]

実装はKaggle Notebook上で行うことで誰もが再現できるコードを目指します。主な読者として

  • 仕事や研究で画像を扱う必要が出てきた方
  • Titanicや住宅価格予測のチュートリアルは終えたが、画像コンペへの取り組み方が分からない方
  • Kaggleの画像分類コンペで使えるベースライン(のベース)を探している方

を想定しています。一方で、画像の前処理や精度を高めるための工夫などについてはスコープ外とします。

コードはGitHubおよびKaggleにて公開しています[4]

それでは、本題に入っていきましょう。

Semantic Segmentation (領域分割/セグメンテーション) と題材

セグメンテーションとは、画像の中の対象物を分類し、ピクセル単位でマークするタスクになります。コンピュータビジョンの花形の一つですね。
より具体的にはセマンティックセグメンテーションとインスタンスセグメンテーションの2種があり、本稿で取り扱うのは前者になります[5][6][7]

今回はKaggleのSartorius - Cell Instance Segmentationコンペを題材とします。

本コンペでは培養細胞の顕微鏡写真から細胞を検出し1つ1つマークすることが目的となります。
学習用の画像として600枚程度の顕微鏡写真と、それぞれの写真における細胞の位置を示す「マスク」情報が準備されています。

Notebookの紹介

ここからコードを抜粋しつつ紹介していきます[8]

コンフィグ設定

ファイルパス、ハイパーパラメータなどの設定を独立したセルで管理することにより、コードの可読性が向上し、ワークフローの整理・デバッグが容易になります。

関数の定義

ノートブック全体で何度も利用する関数はノートブックの最上部で定義しておくと便利です。ノートブックが読みやすくなるだけでなく、デバッグやメンテナンスが容易になります。

探索的データ分析(Exploratory Data analysis: EDA)

データが手元に来たらまずどんなデータがあるのか把握します。

セグメンテーションタスクで重要になるポイントが学習データでどのように領域の情報が記録されているかの確認です。場合によっては自身の使いたいモデルに応じてフォーマットを整える必要があります。

データの読み込みと成形

df = pd.read_csv(DATA_DIR + 'train.csv')
df.head()

これを見ると、細胞の領域のマスクはRun Length Encoding (RLE)[9]で記録されていることがわかります。

領域のマスクは色で塗られた"画像"として提供される場合もあれば、輪郭情報[10]で提供される場合もあります。
マスクの形式を変更する必要がある場合は、OpenCVscikit-imageなどのツールやライブラリを用いて前処理を行います。

今回はRLEからマスクの画像を作ります (Define Helper Functions内で定義)。

また、1行1画像のデータフレームがあったほうが後で便利なので、ここで作っておきます。

grouped_df = group_bboxes(df)
grouped_df.head()

代表画像の可視化

RLEをちゃんと処理できているかの確認も兼ねて、代表画像を1枚見てみましょう。

image_id = grouped_df['id'][0]
img = load_img(f'{DATA_DIR}train/{image_id}.png')
masks = df[df['id'] == image_id]['annotation'].tolist()
masked_img = create_mask_image(img, masks)
plt.figure()
plt.imshow(img)
plt.figure()
plt.imshow(masked_img)

RLEが適切に処理できていることが確認できました。また、実際の写真とマスクの関係も確認ができました。

モデルコンポーネントの定義

本稿ではPyTorchを用いて物体認識モデルを実装します。PyTorchでの機械学習は一般に以下のような流れで行います。

  1. 画像の変形・Augmentationの定義
  2. Datasetの定義
  3. Dataloaderの定義
  4. Modelの定義
  5. 学習の実行

画像の変形とAugmentation

まず最初に、画像に適用する変形とAugmentation(データ拡張)の定義を行います。

画像サイズの調整はもちろん、ぼかしや回転、平行移動を行うことで画像の枚数を擬似的に増やすことができます。注意点として、適用すべきAugmentation手法はタスク次第で、例えば「画像の向きに意味があるような画像」であれば回転は不適切な適用となります。

これまで同様、本稿でもAlbumentations[11]を用いて実装します。
セグメンテーションタスクに置いては、元画像のみでなく正解マスクも合わせて変形させる必要があります。Albumentationsを使えば簡単に行うことができます。Augmentationの定義は画像のみに対してのときと変わらず、後に定義するDatasetの中でマスクにも当てはめるという処理を書きます。

# Image Transformation & Augmentation
def transform_train():
    transforms = [
        A.Resize(256,256,p=1),
        A.HorizontalFlip(p=0.5),
        A.Transpose(p=0.5),
        ToTensorV2(p=1)
    ]
    return A.Compose(transforms)


# Validation images undergo only resizing.
def transform_valid():
    transforms = [
        A.Resize(256,256,p=1),
        ToTensorV2(p=1)
    ]
    return A.Compose(transforms)

Datasetの定義

次に、自身の持つデータの形式に合わせDatasetを定義します。また、セグメンテーションタスクではDatasetの中で「transformsをマスクにも適用する」処理を加える必要があります。
今回は学習にも推論にも同じDatasetを使いたいので、引数stageを用いて処理を分岐させます。

# Dataset
class CellDataset(Dataset):
    def __init__(self, image_ids, dataframe, data_root, transforms=None, stage='train'):
        super().__init__()
        self.image_ids = image_ids
        self.dataframe = dataframe
        self.data_root = data_root
        self.transforms = transforms
        self.stage = stage

    def __len__(self):
        return self.image_ids.shape[0]
    
    def __getitem__(self, index):
        image_id = self.image_ids[index]
        # Load images
        image  = load_img(f'{self.data_root}{image_id}.png').astype(np.float32)

        # 3 channels to 1 channel
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        image /= 255.0 # normalization

        # For training and validation
        if self.stage == 'train':
            # masks
            masks = self.dataframe[self.dataframe['id'] == image_id]['annotation'].tolist()
            mask_image = create_mask_image(image, masks)

            # Transform images and masks
            if self.transforms:
                transformed = self.transforms(image=image, mask=mask_image)
                image, mask_image = transformed['image'], transformed['mask']
            return image, mask_image, image_id
        
        # For test
        else:
            # Transform images
            if self.transforms:
                image =self.transforms(image=image)['image']
            
            return image, image_id

DataLoaderの定義

次に、DataLoaderを定義します。お約束の書き方です[12]

# DataLoader
def create_dataloader(grouped_df, df, trn_idx, val_idx):
    train_ = grouped_df.loc[trn_idx,:].reset_index(drop=True)
    valid_ = grouped_df.loc[val_idx,:].reset_index(drop=True)

    # Dataset
    train_datasets = CellDataset(train_['id'].to_numpy(), df, DATA_DIR+'train/', transforms=transform_train())
    valid_datasets = CellDataset(valid_['id'].to_numpy(), df, DATA_DIR+'train/', transforms=transform_valid())

    # Data Loader
    train_loader = DataLoader(train_datasets, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True, multiprocessing_context='fork')
    valid_loader = DataLoader(valid_datasets, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False, multiprocessing_context='fork')

    return train_loader, valid_loader

Modelの定義

U-Netを構築していきます。

class DoubleConv(nn.Module):
    """DoubleConv is a basic building block of the encoder and decoder components. 
    Consists of two convolutional layers followed by a ReLU activation function.
    """    
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.double_conv(x)
        return x


class Down(nn.Module):
    """Downscaling.
    Consists of two consecutive DoubleConv blocks followed by a max pooling operation.
    """    
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        x = self.maxpool_conv(x)
        return x


class Up(nn.Module):
    """Upscaling.
    Performed using transposed convolution and concatenation of feature maps from the corresponding "Down" operation.
    """
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # input tensor shape: (batch_size, channels, height, width)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x
    

class UNet(nn.Module):
    def __init__(self, n_channels=1, n_classes=1, bilinear=False):
        super(UNet, self).__init__()
        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        
        self.down4 = Down(512,1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        x = torch.sigmoid(x)
        return x

学習と評価を行う

さて、いよいよ学習です。

これまでの記事とは異なり、交差検証と本学習を共通のコードで行います[13]。交差検証で最適なアーキテクチャ、エポック数、ハイパーパラメータ等を見つけたら、それらの情報を使ってモデルを全てのデータで訓練し直します[14]

if TRAIN_ALL: # TRAIN_ALLという変数で条件を分岐。Trueなら全データで学習。
    # Train with all data
    folds = [['','']]
else: 
    # Cross validation
    folds = KFold(n_splits=FOLD_NUM, shuffle=True, random_state=SEED)\
            .split(np.arange(grouped_df.shape[0]), grouped_df['id'].to_numpy())
    
    # For Visualization
    train_loss_list = []
    valid_loss_list = []


for fold, (trn_idx, val_idx) in enumerate(folds):
    # Load Data   
    if TRAIN_ALL:
        train_datasets = CellDataset(grouped_df['id'].to_numpy(), df, DATA_DIR+'train/', transforms=transform_train())
        train_loader = DataLoader(train_datasets, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True, multiprocessing_context='fork')
    else:
        print(f'==========Cross-Validation Fold {fold+1}==========')   
        train_loader, valid_loader = create_dataloader(grouped_df, df, trn_idx, val_idx)
        # For Visualization
        valid_losses = []

    train_losses = []
    # Load model, loss function, and optimizing algorithm
    model = UNet().to(device)
    criterion = nn.BCELoss().to(device)
    optimizer = optim.SGD(model.parameters(), weight_decay=WEIGHT_DECAY, lr = LR, momentum=MOMENTUM)
    
    # Start training
    best_loss = 10**5
    for epoch in range(EPOCHS):
        time_start = time.time()
        print(f'==========Epoch {epoch+1} Start Training==========')
        model.train()
        train_loss = 0
        pbar = tqdm(enumerate(train_loader), total=len(train_loader))
        for step, (imgs, masks, image_ids) in pbar:
            imgs = imgs.to(device).float()
            # imgs = torch.squeeze(imgs)
            masks = masks.to(device).float()
            masks = masks.view(imgs.shape[0], -1, 256, 256)

            optimizer.zero_grad()
            
            output = model(imgs)
            loss = criterion(output, masks)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        train_loss /= len(train_loader)

        # Validation
        if TRAIN_ALL == False:
            print(f'==========Epoch {epoch+1} Start Validation==========')
            
            with torch.no_grad():
                valid_loss = 0
                preds = []
                pbar = tqdm(enumerate(valid_loader), total=len(valid_loader))
                for step, (imgs, masks, image_ids) in pbar:
                    imgs = imgs.to(device).float()
                    # imgs = torch.squeeze(imgs)
                    masks = masks.to(device).float()
                    masks = masks.view(imgs.shape[0], -1, 256, 256)
            
                    val_output = model(imgs)
                    val_loss = criterion(val_output, masks)
                    
                    valid_loss += val_loss.item()
                valid_loss /= len(valid_loader)
                
        # print results from this epoch
        exec_t = int((time.time() - time_start)/60)
        if TRAIN_ALL:
            print(f'Epoch : {epoch+1} - loss : {train_loss:.4f} / Exec time {exec_t} min\n')

        else:
            print(
                f'Epoch : {epoch+1} - loss : {train_loss:.4f} - val_loss : {valid_loss:.4f} / Exec time {exec_t} min\n'
            )
            # For visualization
            train_losses.append(train_loss)
            valid_losses.append(valid_loss)
    
    if TRAIN_ALL:
        print(f'Save model trained with all data')
        os.makedirs(MODEL_DIR, exist_ok=True)
        torch.save(model.state_dict(), MODEL_DIR+'segmentation.pth')
        del model, optimizer, train_loader
    else:
        train_loss_list.append(train_losses)
        valid_loss_list.append(valid_losses)
        del model, optimizer, train_loader, valid_loader, train_losses, valid_losses
    gc.collect()
    torch.cuda.empty_cache()

if TRAIN_ALL == False:
    # Define Helper Functions内で定義
    show_validation_score(train_loss_list, valid_loss_list)

学習が収束仕切っていませんね。すなわちEpoch数が十分でないと考えられます。本稿ではこの設定のままで次のステップに進みます。
TRAIN_ALL = Trueと変えて上のセルを再実行し、全データで学習した.pthファイルを保存します[15]

推論を行う

全データを用いての学習結果が保存できたら、新しいデータに対しての推論を行ってみましょう。
前回はDatasetをテストデータのために作り直していましたが、今回はDataset()に引数 stage='test'を与えることで、同じDatasetを使いまわせるようにしています。

files = os.listdir(DATA_DIR+'test/')
image_ids = np.array([os.path.splitext(file)[0] for file in files])
ids = []
rle_test_preds = []
original_size = (704, 520) # (width, height)

# Load Data
test_datasets = CellDataset(image_ids, df, DATA_DIR+'test/', transforms=transform_valid(), stage='test')

# Data Loader
test_loader = DataLoader(test_datasets, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False, multiprocessing_context='fork')

# Load model, loss function, and optimizing algorithm
model = UNet().to(device)
model.load_state_dict(torch.load(MODEL_DIR+'segmentation.pth'))
    
# Start Inference
print(f'==========Start Inference==========')
with torch.no_grad():
    test_preds = []
    pbar = tqdm(enumerate(test_loader), total=len(test_loader))
    for step, (imgs, image_ids) in pbar:
        imgs = imgs.to(device).float()
        output = model(imgs)

        # Convert the output from PyTorch to np.array
        output = output.detach().cpu().numpy()
        
        # run length encoding
        for image_id, predicted_mask in zip(image_ids, output):
            predicted_mask = np.squeeze(predicted_mask)
            
            # resize
            predicted_mask = cv2.resize(predicted_mask, original_size)
            
            rle_mask = encode_rle(predicted_mask)
            ids.append(image_id)
            rle_test_preds.append(rle_mask)

submission_df = pd.DataFrame({
    'id': ids, 'predicted': rle_test_preds
})
print(submission_df.head())

推論結果の可視化

このままでは推論がうまくいっているのかよくわかりませんので、1つ適当な画像を可視化してみましょう。

target = submission_df.iloc[0]
img = load_img(f'{DATA_DIR}test/{target["id"]}.png')
masks = [target['predicted']]
masked_img = create_mask_image(img, masks)
plt.figure()
plt.imshow(img)
plt.figure()
plt.imshow(masked_img)

ノイジーですが、悪くない印象です。

精度を上げるために

本稿ではセグメンテーションモデルを一から構築しました。これをベースとし、精度を上げる様々な工夫を加えていくことになります。
本稿では学習が収束する前に止めてしまっているので、まずはエポック数を増やしてどこまで変わるか見てみたいですね。
その上で精度の向上を目指すと、例えば、

  • 画像の前処理
  • より幅広いAugmentationの適用
  • 異なるモデルアーキテクチャを採用する
  • ハイパーパラメータチューニング
  • 損失関数の変更
  • 2値化閾値の最適化
  • 複数モデルのアンサンブル
  • 後処理
    • ノイズを消す等
      • 例:細胞はある程度の大きさがあるはず→小さいものはカットする

などがパッと思いつくところでしょうか。こうした工夫の情報はKaggleのDiscussionからも得ることができます。
こうした工夫の有効性はタスクの特性に大きく依存しますMLFlowなどの実験管理ツールで使った設定やその時の結果を記録しながら、様々なテクニックの試行錯誤を繰り返し精度の向上を目指していきます[16]


分類モデル物体認識モデルに続き、本稿ではセグメンテーションモデルの実装を行いました。コンピュータビジョンモデリングの基礎シリーズはこれで一区切りにしたいと思います。
本シリーズがみなさんのコンピュータビジョンモデル構築の助けになることを願っています[17]

謝辞

本稿の執筆にあたり、えんがわAI研究所 所長まっちゃんさん、相場雅彰さん、品原悠杜さんにご協力いただきました。

参考情報

脚注
  1. プログラマーが使っている独特なプログラマー用語にはどんなものがありますか? ↩︎

  2. まだまだ勉強中のため、本稿には誤りや無駄も多いと思います。厳しいご意見をいただけると嬉しいです。 ↩︎

  3. 最先端の手法でなければライブラリ化されていることがほとんどですが、古典的なモデルでも自分で組んでみると色々な学びがありました。 ↩︎

  4. GitHub/Kaggleで公開しているため、本稿ではコードは一部のみの抜粋に留めます。記事が長くなりすぎますしね。 ↩︎

  5. セマンティック:画像中の全てのピクセルに対してラベルを付ける。本稿では全ピクセルに対し「細胞ですか?」と聞き、閾値以上の確度であればマークするというアプローチを取っています。 ↩︎

  6. インスタンス:画像中の物体に対してラベルをつける=物体認識+塗りつぶしというイメージ。 ↩︎

  7. セマンティックとインスタンスを組み合わせた「パノプティックセグメンテーション」という手法も。 ↩︎

  8. 前回と同様のものは省略気味で進めますが、本稿のみで完結させるため重複した内容も多々あります。 ↩︎

  9. RLEって何?にわかりやすい回答がついているKaggle Discussion ↩︎

  10. Cocoフォーマットがこの形式。提示された座標をつないでいくと1領域が囲まれる。 ↩︎

  11. このパッケージはバージョンによって書き方が結構変わる。本稿では1.3.0を使用。 ↩︎

  12. M1 Macユーザーは、 DataLoader()の引数にmultiprocessing_context='fork'の明記が必要なのでご注意ください。M1 mac でmultiprocessに失敗する問題の対処法 ↩︎

  13. これまではわかりやすさのために分けていましたが、メンテナンス性を考えると、実践では本稿の書き方が良いと思います。 ↩︎

  14. 全てのデータで学習し直すべきか、交差検証の過程で作られた複数のモデルを用いる(推論結果は複数モデルのアンサンブルとする)べきかは議論があります。実践的には推論時間や実装の手間、利用可能な計算リソースを考えていずれのアプローチを取るか決定することになるかと思います。参考:【機械学習】交差検証後の最終モデルの選び方 ↩︎

  15. コードの保存にはstate_dict()を使用。 ↩︎

  16. この試行錯誤がモデルづくりの一番の楽しみ。 ↩︎

  17. 大規模言語モデルの学習データとして使われることも期待しています。 ↩︎

Aidemy Tech Blog

Discussion