🙂

huggingfaceで公開されてるVision Transformer (ViT)のモデルを転移学習する

2023/05/05に公開

はじめに

画像認識のタスクに対して Transformer を使うモデル(ViT)は、大きなデータセットを長く学習させる場合において高い精度が出る。ただ、CNN ベースのモデルとは違い、ViT のは実装が難しい。さらに、自前で実装出来たとしても初期から学習して高い精度を実現するには大きなデータセットと長い学習が必要になる。
そのため、この記事では、huggingface で公開されている ViT の実装[1]とモデル[2]を用いて、転移学習を行い自分で用意した画像の分類する。

詳しいコードは以下に記載した。この記事では重要な部分だけ解説する。

データセット

この記事では以下のデータセットを用いる。

https://www.kaggle.com/datasets/muhammadhananasghar/human-emotions-datasethes

顔の表情(angry, happy, sad)の画像と、その他(nothing)の画像が入っている。

さらにフォルダの構造は以下のようになっている。

EmotionsDataset
└── data
    ├── angry
    │   ├── 0.jpg
    │   ├── 1000.jpg
    ...
    │   └── 9.jpg
    ├── happy
    │   ├── 0.jpg
    │   ├── 1000.jpg
    ...
    │   └── 9.jpg
    ├── nothing
    │   ├── 0.jpg
    │   ├── 1000.jpg
    ...
    │   └── 9.jpg
    └── sad
        ├── 0.jpg
        ├── 100.jpg
        ...
        └── 9.jpg

6 directories, 3945 files

他のデータセットに対してもこのように画像を分類すればこの記事の内容で転移学習できる。

前処理

以下のようにして画像を tf.data として扱う。

# images directory

dataset_directory = './dataset/EmotionsDataset/data/'

# create raw datasets

raw_train_dataset = tf.keras.utils.image_dataset_from_directory(
    dataset_directory,
    labels='inferred',
    label_mode='categorical',
    class_names=CLASS_NAMES,
    color_mode='rgb',
    image_size=(IM_SIZE, IM_SIZE),
    validation_split=0.2,
    seed=123,
    subset='training',
)

さらに、シャッフルして画像のスケールを変換してランダムな左右反転のレイヤを追加する。

# define augment function

augment_layers = tf.keras.Sequential([
  Resizing(IM_SIZE, IM_SIZE),
  Rescaling(1./255),
  RandomFlip("horizontal")
])

@tf.function
def augment(image, label):
  return augment_layers(image, training=True), label

# apply augment function to raw datasets

train_dataset = raw_train_dataset\
    .shuffle(buffer_size=1024, reshuffle_each_iteration = True)\
    .map(augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)\
    .prefetch(tf.data.AUTOTUNE)

モデル構築

transformers にある ViT のモデルを設定する。今回は事前学習されたモデルを使用する。

from transformers import TFViTModel

vit_model = TFViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

これを用いてモデルを構築する。

inputs = Input(shape=(IM_SIZE, IM_SIZE, 3))
x = Resizing(224, 224)(inputs)
x = Permute((3, 1, 2))(x)
vit = vit_model(x)
x = vit.pooler_output
outputs = Dense(NUM_CLASSES, activation='softmax')(x)

model = Model(inputs=inputs, outputs=outputs)

入力の画像のサイズは(224, 224)なので変換が必要で、channel と x, y の順番も異なるので転置が必要。vit_model の出力には last_hidden_state と pooler_output があるが、画像の分類タスクではpooler_outputを用いる。

モデルの学習

model.compile(
    optimizer= "adam",
    loss = CategoricalCrossentropy(),
    metrics = ['accuracy'],
)

history = model.fit(
    train_dataset,
    validation_data = test_dataset,
    epochs = N_EPOCHS,
    verbose = 1,
    class_weight=class_weight,
    callbacks = [EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)]
)

おわりに

ViT のモデルの学習した。このデータセットは CNN だと精度が上手く出ないが、事前学習した ViT のモデルを用いると精度がよくなった。また、transformers を用いると難しいモデルでも自前のモデルに簡単に組み込めることがわかった。

また、transformers は基本 PyTorch 前提で作られているみたいなので、PyTorch のエコシステムも勉強した方がいいと思った。(Tensorflow の方が API が豊富でコードの記述量が少なくて使いやすい気がする。巷では Pytorch の方が流行っている。これがなぜなのかよくわからない。)

脚注
  1. https://huggingface.co/docs/transformers/main/en/model_doc/vit ↩︎

  2. https://huggingface.co/google/vit-base-patch16-224-in21k ↩︎

Discussion