🐕

TensorflowのResNet50の事前学習モデルをU-Netに転移学習させる方法

2023/05/14に公開

はじめに

Tensorflow には ImageNet で事前学習された ResNet50[1]が付属している。ResNet はその中間レイヤーを用いることで、画像識別系の様々なタスクに応用することができる。

この記事では、この事前学習された ResNet50 のモデルを利用して U-Net のモデル構築し画像セグメンテーションのタスクを行う。この時、事前学習された Weight を使う場合と使わない場合で、精度や学習時間にどのような違いが現れるのか、比較を行う。

コードの全体は以下に配置した。この記事では重要な部分を解説する。

データセット

この記事では、以下データセットを用いる。

https://www.tensorflow.org/datasets/catalog/oxford_iiit_pet

このデータセットは、犬や猫の画像に対して対象物と境界と背景のセグメンテーションが行われている。train データセットは 3680 枚の画像で構成されており、比較的小さいデータセットであるため、事前に学習された ResNet50 の転移学習が有効であると期待される。

データ準備

試験データを以下のように加工する。

  1. 画像を(224, 224, 3)にリサイズし、値を 0〜1 に変換。
  2. マスクを(112, 112, 1)にリサイズし、label を 0, 1, 2 に変換する。

tensorflow のみで次のように実装した。

import tensorflow_datasets as tfds
from tensorflow.keras.layers import Resizing, Rescaling

dataset, info = tfds.load('oxford_iiit_pet', with_info=True)

# prepare test dataset

@tf.function
def data_preprocess(data):
    image = data['image']
    image = Resizing(224, 224)(image)
    image = Rescaling(1./255)(image)

    mask = data['segmentation_mask'] - 1
    mask = Resizing(112, 112)(mask)

    return image, mask

processed_test_dataset = dataset['test'].map(data_preprocess, num_parallel_calls=tf.data.AUTOTUNE).cache()

Data Augmentation

学習データに対しては Data Augmentation を実施する。今回のデータセットに対しては次のような加工を実施した

  1. -30°〜30° の間でランダムに回転
  2. 水平方向に反転
  3. 0.7〜1 の間でランダムにスケールし(224, 224)をランダムに切り出す。

今回は albumentations を用いて実装した。

import albumentations as albu

# prepare train dataset

transforms = albu.Compose([
    albu.Rotate(30),
    albu.HorizontalFlip(),
    albu.RandomResizedCrop(224, 224, scale=(0.7, 1.0)),
])

def aug_albument(image, mask):
    data = transforms(image=image, mask=mask)
    aug_image = tf.convert_to_tensor(data["image"], dtype=tf.float32)
    aug_mask = tf.convert_to_tensor(data["mask"], dtype=tf.float32)
    aug_mask = tf.image.resize(aug_mask, (112, 112))
    return [aug_image, aug_mask]

def process_data(data):
    [aug_image, aug_mask] = tf.numpy_function(
        func=aug_albument,
        inp=[data["image"], data["segmentation_mask"]],
        Tout=[tf.float32, tf.float32]
    )
    return aug_image / 255., aug_mask - 1.

processed_train_dataset = dataset['train'].cache().map(process_data, num_parallel_calls=tf.data.AUTOTUNE)

モデル構築

U-Net は以下のようなモデルである。はじめに画像を Encode して特徴を抽出し、Decode してセグメンテーションを作成する。その際に、Encoder の中間レイヤーを Decoder に渡すことで、元の画像の位置関係を後ろのレイヤーまで伝達する効果や、勾配損失の軽減などが期待できるのだと思う。


U-Net の論文[2]から引用

上記の考えに基づいて、tensorflow の ResNet50 を以下のような U-Net のモデルに改造した。

左の赤で囲った部分が ResNet50 である。この部分の事前に学習されたものを用いる。右側は初期値から学習を行う。

tensorflow で実装すると次のようになる。

class UNet(Model):
    def __init__(self):
        super(UNet, self).__init__()
        self.base_model = ResNet50(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
#        self.base_model.trainable = False

        block_1_expand_relu, block_2_expand_relu, block_3_expand_relu, block_4_expand_relu, block_5_expand_relu = [
            self.base_model.get_layer(name).output for name in [
                "conv1_relu",
                "conv2_block3_out",
                "conv3_block4_out",
                "conv4_block6_out",
                "conv5_block3_out"
            ]
        ]
        self.encoder = Model(inputs=self.base_model.input, outputs=[block_1_expand_relu, block_2_expand_relu, block_3_expand_relu, block_4_expand_relu, block_5_expand_relu])

        self.UpConv_1 = Sequential([
            Conv2DTranspose(1024, 1, strides=2, activation='relu'),
            BatchNormalization(),
        ], name='UpConv_1')

        self.Conv_1 = Sequential([
            Conv2D(1024, 3, padding='same', activation='relu'),
            BatchNormalization(),
            Conv2D(1024, 3, padding='same', activation='relu'),
            BatchNormalization(),
        ], name='Conv_1')

        self.UpConv_2 = Sequential([
            Conv2DTranspose(512, 1, strides=2, activation='relu'),
            BatchNormalization(),
        ], name='UpConv_2')

        self.Conv_2 = Sequential([
            Conv2D(512, 3, padding='same', activation='relu'),
            BatchNormalization(),
            Conv2D(512, 3, padding='same', activation='relu'),
            BatchNormalization(),
        ], name='Conv_2')

        self.UpConv_3 = Sequential([
            Conv2DTranspose(256, 1, strides=2, activation='relu'),
            BatchNormalization(),
        ], name='UpConv_3')

        self.Conv_3 = Sequential([
            Conv2D(256, 3, padding='same', activation='relu'),
            BatchNormalization(),
            Conv2D(256, 3, padding='same', activation='relu'),
            BatchNormalization(),
        ], name='Conv_3')

        self.UpConv_4 = Sequential([
            Conv2DTranspose(64, 1, strides=2, activation='relu'),
            BatchNormalization(),
        ], name='UpConv_4')

        self.Conv_4 = Sequential([
            Conv2D(64, 3, padding='same', activation='relu'),
            BatchNormalization(),
            Conv2D(64, 3, padding='same', activation='relu'),
            BatchNormalization(),
        ], name='Conv_4')

        self.output_layer = Conv2D(3, 3, padding='same', activation='softmax', name='output_layer')

    def call(self, inputs):
        x_112, x_56, x_28, x_14, x_7 = self.encoder(inputs)

        x = self.UpConv_1(x_7)
        x = tf.concat([x, x_14], axis=-1)
        x = self.Conv_1(x)

        x = self.UpConv_2(x)
        x = tf.concat([x, x_28], axis=-1)
        x = self.Conv_2(x)

        x = self.UpConv_3(x)
        x = tf.concat([x, x_56], axis=-1)
        x = self.Conv_3(x)

        x = self.UpConv_4(x)
        x = tf.concat([x, x_112], axis=-1)
        x = self.Conv_4(x)

        x = self.output_layer(x)
        return x

model = UNet()
model.build((32, 224, 224, 3))

ResNet50 から必要なレイヤーのアウトプットだけ抜き出しモデルを作成する。それをベースモデルとして用いて、そのほかのレイヤーを付け加えていく。元の論文を参考に UpConv のチャネル数は concat で 1:1 で交じり合うように調整した。

モデルの学習

モデルの学習は以下のような条件で行なった。

# compile model
model.compile(
    loss = SparseCategoricalCrossentropy(),
    optimizer = Adam(),
    metrics='accuracy',
)
model.summary()

# early stopping
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    verbose=1,
    restore_best_weights=True
)

# train model
model.fit(train_dataset, epochs=100, validation_data=test_dataset, shuffle=True, callbacks=[early_stopping])

学習の結果、epoch=9 でストップした。

試験データも上手にセグメンテーションできた。犬自体は上手に識別出来ている。口に咥えている棒までは識別出来ていない。

転移学習しない場合

事前に学習された weight を使わずに学習した結果、以下のようになった。

epoch=18 でストップした。accuracy の値も転移学習した場合に比べ悪い。

輪郭をうまく捉えられていない部分がある。画像数が少ないため CNN のカーネルが画像の輪郭を捉えられるように学習できてないことが予想される。学習時間や精度どちらに置いても転移学習した場合に比べて明らかに悪い。

おわりに

tensorflow の ResNet50 を用いて U-Net のモデルを構築できた。さらに、事前に学習された weight を用いることで、事前学習少ないデータセットであっても高い精度と速い学習時間を実現できることを確認した。

tensorflow には ResNet50 以外にも EfficientNet や MobileNet の weight が公開されているので、同様の転移学習できると思う。さらには U-Net 以外にも YOLO のようなモデルについても転移学習を行えると思う。これらは今後試したい。

脚注
  1. https://www.tensorflow.org/api_docs/python/tf/keras/applications/resnet50/ResNet50 ↩︎

  2. https://arxiv.org/abs/1505.04597 ↩︎

Discussion