TensorFlowでカスタム訓練ループをfitに組み込むための便利な書き方

2021/01/22に公開

TensorFlow(Keras)でバッチ単位の訓練をカスタマイズしつつ、fitで訓練する便利な書き方を見つけたので紹介していきます。

きっかけ

Kerasの作者のツイッター見てたら面白い書き方していました。

https://twitter.com/fchollet/status/1348333356795088897

例はVAEでしたが、もっと簡単に超解像の問題で追ってみました。

元のNotebook:https://colab.research.google.com/drive/1veeJMSRE7yzE3q-OulyPqbsuAuk00yxK?usp=sharing

環境:TensorFlow 2.3.1

問題意識

TensorFlowでちょっと複雑なモデルを訓練すると、独自の訓練ループを書いたほうが便利なことがあります。GANやVAEのように複数のモデルを訓練するケースがそうですね。次のような書き方です。

    @tf.function # 高速化のためのデコレーター
    def train_on_batch(X, y):
        with tf.GradientTape() as tape:
            pred = model(X, training=True)
            loss_val = loss(y, pred)
        graidents = tape.gradient(loss_val, model.trainable_weights)
        optim.apply_gradients(zip(graidents, model.trainable_weights))
        acc.update_state(y, pred)
        return loss_val

自分が昔書いた記事のカスタム訓練ループから、バッチ単位の関数です。データ全体をforループでバッチ単位に切り出して、

    for step, (X, y) in enumerate(trainset):
        loss_val = train_on_batch(X, y)

こんな感じに回す、PyTorchみたいな書き方です。この書き方は便利ですが、TensorFlow(Keras)のいいところを損なっています。それは、

「Kerasの便利なcompile→fitというAPIを使えない」

ということです。バッチ単位の処理をカスタムしつつ、fitで訓練したいという需要に応えたのが今回紹介する書き方です。

モデル設定

今回はCIFAR-10の超解像の訓練をします。超解像のようにロスとは別の評価関数(PSNR)を使う問題のほうが、この書き方のありがたみがわかります。CIFAR-10でいきます。

  • 入力:16×16の画像
  • 出力:32×32の画像

入力画像を2倍に拡大する超解像の問題です。論文ではSingle Image Super-Resolution(SISR)と言われる問題です。

もともとCIFAR-10は32×32の解像度なので、半分のサイズ(16×16)にリサイズしたものを入力に入れ、オリジナル画像を教師データとして訓練すればいいだけです。問題としてはシンプルです。

PSNR

超解像の問題ではPSNRという評価指標をよく使います。単位はdBです。ロスの値を直接見るよりも、PSNRを見たほうが直感的にイメージしやすくなります。PSNRは以下の定義です。

PSNR=10\log_{10}\frac{MAX_I^2}{MSE}

MSEは平均2乗誤差、MAX_Iは画素の最大値を表します。0-1スケールならMAX_I=1、0-255スケールならMAX_I=255です。このコードでは0-1スケールで説明します。

PSNRの実装はTensorFlowの組み込み関数を使いましょう。

tf.image.psnr(highres_pred, highres_true, 1.0) # 1.0=MAX_I

これでサンプル単位のPSNRが計算されます。ホワイトノイズに対して計算すると次のようになります。

x = np.random.uniform(size=(5, 32, 32, 3))
y = np.random.uniform(size=(5, 32, 32, 3))
psnr = tf.image.psnr(x, y, 1.0)
print(psnr)
# tf.Tensor([7.8894033 7.6707745 7.8617544 7.8076196 7.8329306], shape=(5,), dtype=float32)

コード

先に全体のコードを示します

import tensorflow as tf
import tensorflow.keras.layers as layers
import numpy as np

# CuDNNの初期化エラー対策のためにメモリを制限する
# 参考:https://qiita.com/masudam/items/c229e3c75763e823eed5
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

class SuperResolutionModel(tf.keras.models.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # 複数モデルを入れ子にすることもOK
        self.model = self.create_model()
        # トラッカーを用意する(訓練、テスト共通で良い)
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.psnr_tracker = tf.keras.metrics.Mean(name="psnr")

    def create_model(self):
        inputs = layers.Input((16, 16, 3))
        x = layers.Conv2D(64, 3, padding="same")(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        x = layers.UpSampling2D(2)(x)
        x = layers.Conv2D(32, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        x = layers.Conv2D(3, 3, padding="same", activation="sigmoid")(x)
        return tf.keras.models.Model(inputs, x)

    # なくてもエラーは出ないが、訓練・テスト間、エポックの切り替わりで
    # トラッカーがリセットされないため、必ずmetricsのプロパティをオーバーライドすること
    # self.reset_metrics()はこのプロパティを参照している
    @property
    def metrics(self):
        return [self.loss_tracker, self.psnr_tracker]

    def train_step(self, data):
        low_res_input, high_res_gt = data

        with tf.GradientTape() as tape:
            high_res_pred = self.model(low_res_input)
            loss = tf.reduce_mean(tf.abs(high_res_gt - high_res_pred))
        # 全体(self)に対する偏微分か、特定モデル(self.model)に対する微分かは場合により変わる
        # このケースではどちらでも同じだが、GANでは使い分ける必要がある
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        psnr = tf.reduce_mean(tf.image.psnr(high_res_pred, high_res_gt, 1.0))

        # エポックの切り替わりのトラッカーのリセットは、self.reset_metrics()で自動的に行われる
        self.loss_tracker.update_state(loss)
        self.psnr_tracker.update_state(psnr)
        return {
            "loss": self.loss_tracker.result(),
            "psnr": self.psnr_tracker.result(),
        }

    def test_step(self, data):
        low_res_input, high_res_gt = data

        high_res_pred = self.model(low_res_input)
        loss = tf.reduce_mean(tf.abs(high_res_gt - high_res_pred))
        psnr = tf.reduce_mean(tf.image.psnr(high_res_pred, high_res_gt, 1.0))

        # 訓練・テストの切り替わりのトラッカーのリセットは、self.reset_metrics()で自動的に行われる
        self.loss_tracker.update_state(loss)
        self.psnr_tracker.update_state(psnr)
        return {
            "loss": self.loss_tracker.result(),
            "psnr": self.psnr_tracker.result(),
        }

def main():
    (X_train, _), (X_test, _) = tf.keras.datasets.cifar10.load_data()
    # high res images
    X_train_highres = X_train.astype(np.float32) / 255.0
    X_test_highres = X_test.astype(np.float32) / 255.0
    # low res images
    X_train_lowres = tf.image.resize(X_train_highres, (16, 16))
    X_test_lowres = tf.image.resize(X_test_highres, (16, 16))

    model = SuperResolutionModel()
    model.compile(optimizer=tf.keras.optimizers.Adam())
    model.fit(X_train_lowres, X_train_highres,
              validation_data=(X_test_lowres, X_test_highres),          
              batch_size=128, epochs=10)

if __name__ == "__main__":
    main()

モデルは適当に作りました(軽いのでCPUでも訓練できます)。ポイントをかいつまんで説明します。

トラッカーを用意する

コンストラクタのここに注目。

        # トラッカーを用意する(訓練、テスト共通で良い)
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.psnr_tracker = tf.keras.metrics.Mean(name="psnr")

これはロスの値と、PSNRの値を記録するためのトラッカーです。バッチ単位で各値が積み重なっていき、fitのログではバッチ間の平均が表示されます。metrics.Meanとしているのはバッチ間の平均を取るためです。

このトラッカーをtrain_step, test_stepでアップデートしていきます。

train_step, test_step

この関数は、Kerasのfitevaluateの中で呼ばれるバッチ単位の処理を記述する部分です。ここをカスタマイズすることで、fitのAPIを利用しつつ独自の処理を定義できます。

    def train_step(self, data):
        low_res_input, high_res_gt = data

        with tf.GradientTape() as tape:
            high_res_pred = self.model(low_res_input)
            loss = tf.reduce_mean(tf.abs(high_res_gt - high_res_pred))
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        psnr = tf.reduce_mean(tf.image.psnr(high_res_pred, high_res_gt, 1.0))

        self.loss_tracker.update_state(loss)
        self.psnr_tracker.update_state(psnr)
        return {
            "loss": self.loss_tracker.result(),
            "psnr": self.psnr_tracker.result(),
        }

細かい書き方はこれまでのカスタム訓練と同じです。異なる点がいくつかあります。

1つ目は、オプティマイザを外部から与えず、モデルのもの(インスタンス変数)を使うこと。オプティマイザはモデル内で一切記述していませんが、compileが呼ばれたときに自動で登録されます。

2つ目は、train_stepを@tf.functionのデコレーターで囲まなくてよいこと。このデコレーターは訓練をグラフモードで高速に実行するために必要ですが、そもそもmodelのfitがデフォルトでグラフモードとして実行されるため、ステップ単位の処理を書くだけでOKです。fit自体が@tf.functionで囲まれているようなものです。

metricsのプロパティを実装する

これはなくても動いてしまいますが、全てのトラッカーをリストで入れるように実装してください。後述のトラッカーのリセットの関係で必要になります。

    @property
    def metrics(self):
        return [self.loss_tracker, self.psnr_tracker]

ログの結果をdictで返す

見ればそのとおりですが、Kerasのログに対応する値を書いていきます。

        self.loss_tracker.update_state(loss)
        self.psnr_tracker.update_state(psnr)
        return {
            "loss": self.loss_tracker.result(),
            "psnr": self.psnr_tracker.result(),
        }

出力はこうなります。

391/391 [==============================] - 3s 9ms/step - loss: 0.0559 - psnr: 23.7072 - val_loss: 0.0414 - val_psnr: 25.4800
Epoch 2/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0375 - psnr: 26.1364 - val_loss: 0.0365 - val_psnr: 26.3191
Epoch 3/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0354 - psnr: 26.5446 - val_loss: 0.0343 - val_psnr: 26.7604
Epoch 4/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0343 - psnr: 26.7636 - val_loss: 0.0338 - val_psnr: 26.8522
Epoch 5/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0339 - psnr: 26.8342 - val_loss: 0.0333 - val_psnr: 26.9524
Epoch 6/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0335 - psnr: 26.9146 - val_loss: 0.0334 - val_psnr: 26.9208
Epoch 7/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0332 - psnr: 26.9717 - val_loss: 0.0328 - val_psnr: 27.0393
Epoch 8/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0330 - psnr: 27.0053 - val_loss: 0.0326 - val_psnr: 27.0906
Epoch 9/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0328 - psnr: 27.0457 - val_loss: 0.0324 - val_psnr: 27.1292
Epoch 10/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0328 - psnr: 27.0496 - val_loss: 0.0329 - val_psnr: 27.0406

fitでプログレスバーをいい感じに出してくれますし、訓練データに対してもPSNRを計算してくれるのがとてもいいですね。

validationの部分は「val_」と自動的にプレフィックスをつけてくれます。

Q1:評価関数のリセットはしているのですか?

ここで疑問なのは「評価関数(トラッカー)って訓練テスト共通で用意しているけど、訓練・テストの切り替わりや、エポックの切り替わりでバッチ単位のログをリセットしているのですか?」ということです。これ気になってTensorFlowのソースを読んでみました。

結論は、metricsのプロパティにトラッカーが登録されていれば、リセットの処理を書かなくても自動的にリセットされます。もともとtf.keras.model.Modelsにはreset_metricsという関数が組み込まれています(参考)。エポックの開始時、訓練・テストの切替時にreset_metricsが自動的に呼び出されます。この関数の実装を見てみると、

    for m in self.metrics:
      m.reset_states()

とありました。ここで参照しているのはModelのmetricsのプロパティです。したがって、metricsにトラッカーを登録する必要があるのです。

もし登録しなくてもエラーは特に出ませんが、リセットが行われずに訓練・テスト間でログが使い回されるはずです。ここだけ注意が必要です。

Q2:self.trainable_weightsとself.model.trainable_weightsって何が違うんですか?

このケースでは1つしかモデルがないので特に変わりませんが、GANのように複数のモデルをfitで扱うケースだと意識する必要があります。

前者は入れ子にしているモデルの訓練可能な係数全て、後者は特定モデルの訓練可能な係数全てを表します。GANの場合使い分けが必要ですね。

Q3:ずっとfitにtf.functionがかかているとデバッグで困る。デバッグのときはEagerモードで動かしたい

コンパイル時にrun_eagerly=Trueと指定しましょう。

model.compile(optimizer=tf.keras.optimizers.Adam(), run_eagerly=True)

これでEagerモードに簡単に切り替えられます。train_stepの中でprintをするとテンソルの値が出てくるのでデバッグに最適です。

まとめ

バッチ単位の訓練のカスタマイズと、fitのAPIは両立できるということが示せました。これはかなり便利な書き方だと思います。ぜひ活用してみてください。

Discussion