💽

マラリア細胞検知を題材に、TensorFlowのData Augmentationの方法を解説

2023/04/23に公開

はじめに

画像データを反転や回転させることでデータをかさ増しすることで、画像の AI モデルの学習の精度を向上させる事ができる。これを Data Augmentation という。

TensorFlow には Data Augmentation 用の API[1]が用意されており、簡単に Data Augmentation を実現出来る。この記事では、マラリア細胞[2]の検知モデルを学習させることを題材に TensorFlow の Data Augmentation のやり方を解説する。

この記事で用いたコードの完全版は以下に記載した。この記事では重要な部分だけ解説する。

データの解説

この記事ではマラリアに感染した細胞の画像と健康な細胞の画像が分類されたデータセットを用いる。以下で画像を表示する。

import tensorflow as tf
import tensorflow_datasets as tfds
import seaborn as sns
import matplotlib.pyplot as plt
sns.set()

raw_dataset_, dataset_info = tfds.load('malaria', with_info=True, as_supervised=True, shuffle_files = False, split=['train'])
raw_dataset = raw_dataset_[0]

plt.figure(figsize=(8, 6))
for i, (image, label) in enumerate(raw_dataset.take(16)):
  ax = plt.subplot(4, 4, i + 1)
  plt.imshow(image)
  plt.title(dataset_info.features['label'].int2str(label))
  plt.axis('off')
plt.show()

この画像の特徴は以下である。

  1. 背景が黒塗りされている
  2. 画像のサイズは揃っていない
  3. 回転や反転してもよい

したがって、サイズを揃えて回転や反転を加え余白を黒で埋めることで画像を増やすことが出来る。

モデルの解説

次のような LeNet をベースとした基本的な CNN のモデルを学習する。

DROPOUT_RATE = 0.3
N_FILTERS = 6
KERNEL_SIZE = 3
N_STRIDES = 1
POOL_SIZE = 2
N_DENSE_1 = 128
N_DENSE_2 = 128

model = tf.keras.Sequential([
    InputLayer(input_shape = (IM_SIZE, IM_SIZE, 3)),

    Conv2D(filters = N_FILTERS , kernel_size = KERNEL_SIZE, strides = N_STRIDES , padding='valid', activation = "relu"),
    MaxPool2D (pool_size = POOL_SIZE, strides= N_STRIDES*2),

    Conv2D(filters = N_FILTERS*2 + 4, kernel_size = KERNEL_SIZE, strides=N_STRIDES, padding='valid', activation = "relu"),
    MaxPool2D (pool_size = POOL_SIZE, strides= N_STRIDES*2),

    Conv2D(filters = N_FILTERS*4 + 2, kernel_size = KERNEL_SIZE, strides=N_STRIDES, padding='valid', activation = "relu"),
    MaxPool2D (pool_size = POOL_SIZE, strides= N_STRIDES*2),

    Flatten(),

    Dense(N_DENSE_1, activation = "relu"),
    Dropout(rate = DROPOUT_RATE),
    BatchNormalization(),

    Dense(N_DENSE_2, activation = "relu"),
    BatchNormalization(),

    Dense(1, activation = "sigmoid"),
])

model.summary()

画像のサイズを揃える

はじめに画像のリサイズと正規化のみを実施したデータセットで学習を行う。

IM_SIZE = 224
resize_layer = tf.keras.Sequential([
    Resizing(IM_SIZE, IM_SIZE),
    Rescaling(1./255)
])

@tf.function
def resize(image, label):
    image = resize_layer(image, training=True)
    return image, label

resized_dataset = raw_dataset.map(resize, num_parallel_calls=tf.data.experimental.AUTOTUNE)

Resizingレイヤーを用いて、(IM_SIZE, IM_SIZE)にサイズを揃える。Rescalingレイヤーを用いて、値を[0, 1]の範囲に揃えることが出来る。

変換した結果以下のようになる。

この結果、学習は以下のようになった。

過学習が進んでいる。損失と正解率は以下のようになった。

{ "loss": 0.26634958386421204, "accuracy": 0.9477693438529968 }

画像の反転と回転

次に画像をリサイズと正規化に加えて、反転・回転させたデータセットで学習を行う。

augment_layers = tf.keras.Sequential([
    Resizing(IM_SIZE, IM_SIZE),
    Rescaling(1./255),
    RandomFlip(),
    RandomRotation(0.2, fill_mode='constant', fill_value=0),
])

@tf.function
def augment(image, label):
    image = augment_layers(image, training=True)
    return image, label

augmented_dataset = raw_dataset.map(augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)

RandomFlipレイヤーを用いてバッチごとにランダムな反転を実現出来る。RandomRotationレイヤーを用いてバッチごとにランダムな回転を実現出来る。引数の0.2[-0.2*2*pi, 0.2*2*pi]の間でランダムな回転を加えることを指定する。。さらにfill_modefill_valueは回転によって生まれた隙間をどの値で埋めるのかを指定する。

変換した結果以下のようになる。

この結果、学習は以下のようになった。

過学習が抑えられている。むしろモデルのパラメータが足りないように見える。損失と正解率は以下のようになった。

{'loss': 0.13817277550697327, 'accuracy': 0.9535727500915527

おわりに

TensorFlow の入力パイプラインを用いると画像の Data Augmentation を簡単に実装出来る。今回のデータのように画像が回転・反転を許す場合は非常に強力で、モデルの過学習を抑えることが出来ることを確かめた。

脚注
  1. https://www.tensorflow.org/tutorials/images/data_augmentation ↩︎

  2. https://www.tensorflow.org/datasets/catalog/malaria ↩︎

Discussion