TensorFlowでカスタム訓練ループをfitに組み込むための便利な書き方
TensorFlow(Keras)でバッチ単位の訓練をカスタマイズしつつ、fitで訓練する便利な書き方を見つけたので紹介していきます。
きっかけ
Kerasの作者のツイッター見てたら面白い書き方していました。
例はVAEでしたが、もっと簡単に超解像の問題で追ってみました。
元のNotebook:https://colab.research.google.com/drive/1veeJMSRE7yzE3q-OulyPqbsuAuk00yxK?usp=sharing
環境:TensorFlow 2.3.1
問題意識
TensorFlowでちょっと複雑なモデルを訓練すると、独自の訓練ループを書いたほうが便利なことがあります。GANやVAEのように複数のモデルを訓練するケースがそうですね。次のような書き方です。
@tf.function # 高速化のためのデコレーター
def train_on_batch(X, y):
with tf.GradientTape() as tape:
pred = model(X, training=True)
loss_val = loss(y, pred)
graidents = tape.gradient(loss_val, model.trainable_weights)
optim.apply_gradients(zip(graidents, model.trainable_weights))
acc.update_state(y, pred)
return loss_val
自分が昔書いた記事のカスタム訓練ループから、バッチ単位の関数です。データ全体をforループでバッチ単位に切り出して、
for step, (X, y) in enumerate(trainset):
loss_val = train_on_batch(X, y)
こんな感じに回す、PyTorchみたいな書き方です。この書き方は便利ですが、TensorFlow(Keras)のいいところを損なっています。それは、
「Kerasの便利なcompile→fitというAPIを使えない」
ということです。バッチ単位の処理をカスタムしつつ、fitで訓練したいという需要に応えたのが今回紹介する書き方です。
モデル設定
今回はCIFAR-10の超解像の訓練をします。超解像のようにロスとは別の評価関数(PSNR)を使う問題のほうが、この書き方のありがたみがわかります。CIFAR-10でいきます。
- 入力:16×16の画像
- 出力:32×32の画像
入力画像を2倍に拡大する超解像の問題です。論文ではSingle Image Super-Resolution(SISR)と言われる問題です。
もともとCIFAR-10は32×32の解像度なので、半分のサイズ(16×16)にリサイズしたものを入力に入れ、オリジナル画像を教師データとして訓練すればいいだけです。問題としてはシンプルです。
PSNR
超解像の問題ではPSNRという評価指標をよく使います。単位はdBです。ロスの値を直接見るよりも、PSNRを見たほうが直感的にイメージしやすくなります。PSNRは以下の定義です。
PSNRの実装はTensorFlowの組み込み関数を使いましょう。
tf.image.psnr(highres_pred, highres_true, 1.0) # 1.0=MAX_I
これでサンプル単位のPSNRが計算されます。ホワイトノイズに対して計算すると次のようになります。
x = np.random.uniform(size=(5, 32, 32, 3))
y = np.random.uniform(size=(5, 32, 32, 3))
psnr = tf.image.psnr(x, y, 1.0)
print(psnr)
# tf.Tensor([7.8894033 7.6707745 7.8617544 7.8076196 7.8329306], shape=(5,), dtype=float32)
コード
先に全体のコードを示します
import tensorflow as tf
import tensorflow.keras.layers as layers
import numpy as np
# CuDNNの初期化エラー対策のためにメモリを制限する
# 参考:https://qiita.com/masudam/items/c229e3c75763e823eed5
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
try:
# Currently, memory growth needs to be the same across GPUs
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
logical_gpus = tf.config.experimental.list_logical_devices('GPU')
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
except RuntimeError as e:
# Memory growth must be set before GPUs have been initialized
print(e)
class SuperResolutionModel(tf.keras.models.Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# 複数モデルを入れ子にすることもOK
self.model = self.create_model()
# トラッカーを用意する(訓練、テスト共通で良い)
self.loss_tracker = tf.keras.metrics.Mean(name="loss")
self.psnr_tracker = tf.keras.metrics.Mean(name="psnr")
def create_model(self):
inputs = layers.Input((16, 16, 3))
x = layers.Conv2D(64, 3, padding="same")(inputs)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = layers.UpSampling2D(2)(x)
x = layers.Conv2D(32, 3, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = layers.Conv2D(3, 3, padding="same", activation="sigmoid")(x)
return tf.keras.models.Model(inputs, x)
# なくてもエラーは出ないが、訓練・テスト間、エポックの切り替わりで
# トラッカーがリセットされないため、必ずmetricsのプロパティをオーバーライドすること
# self.reset_metrics()はこのプロパティを参照している
@property
def metrics(self):
return [self.loss_tracker, self.psnr_tracker]
def train_step(self, data):
low_res_input, high_res_gt = data
with tf.GradientTape() as tape:
high_res_pred = self.model(low_res_input)
loss = tf.reduce_mean(tf.abs(high_res_gt - high_res_pred))
# 全体(self)に対する偏微分か、特定モデル(self.model)に対する微分かは場合により変わる
# このケースではどちらでも同じだが、GANでは使い分ける必要がある
grads = tape.gradient(loss, self.trainable_weights)
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
psnr = tf.reduce_mean(tf.image.psnr(high_res_pred, high_res_gt, 1.0))
# エポックの切り替わりのトラッカーのリセットは、self.reset_metrics()で自動的に行われる
self.loss_tracker.update_state(loss)
self.psnr_tracker.update_state(psnr)
return {
"loss": self.loss_tracker.result(),
"psnr": self.psnr_tracker.result(),
}
def test_step(self, data):
low_res_input, high_res_gt = data
high_res_pred = self.model(low_res_input)
loss = tf.reduce_mean(tf.abs(high_res_gt - high_res_pred))
psnr = tf.reduce_mean(tf.image.psnr(high_res_pred, high_res_gt, 1.0))
# 訓練・テストの切り替わりのトラッカーのリセットは、self.reset_metrics()で自動的に行われる
self.loss_tracker.update_state(loss)
self.psnr_tracker.update_state(psnr)
return {
"loss": self.loss_tracker.result(),
"psnr": self.psnr_tracker.result(),
}
def main():
(X_train, _), (X_test, _) = tf.keras.datasets.cifar10.load_data()
# high res images
X_train_highres = X_train.astype(np.float32) / 255.0
X_test_highres = X_test.astype(np.float32) / 255.0
# low res images
X_train_lowres = tf.image.resize(X_train_highres, (16, 16))
X_test_lowres = tf.image.resize(X_test_highres, (16, 16))
model = SuperResolutionModel()
model.compile(optimizer=tf.keras.optimizers.Adam())
model.fit(X_train_lowres, X_train_highres,
validation_data=(X_test_lowres, X_test_highres),
batch_size=128, epochs=10)
if __name__ == "__main__":
main()
モデルは適当に作りました(軽いのでCPUでも訓練できます)。ポイントをかいつまんで説明します。
トラッカーを用意する
コンストラクタのここに注目。
# トラッカーを用意する(訓練、テスト共通で良い)
self.loss_tracker = tf.keras.metrics.Mean(name="loss")
self.psnr_tracker = tf.keras.metrics.Mean(name="psnr")
これはロスの値と、PSNRの値を記録するためのトラッカーです。バッチ単位で各値が積み重なっていき、fitのログではバッチ間の平均が表示されます。metrics.Mean
としているのはバッチ間の平均を取るためです。
このトラッカーをtrain_step, test_step
でアップデートしていきます。
train_step, test_step
この関数は、Kerasのfit
やevaluate
の中で呼ばれるバッチ単位の処理を記述する部分です。ここをカスタマイズすることで、fitのAPIを利用しつつ独自の処理を定義できます。
def train_step(self, data):
low_res_input, high_res_gt = data
with tf.GradientTape() as tape:
high_res_pred = self.model(low_res_input)
loss = tf.reduce_mean(tf.abs(high_res_gt - high_res_pred))
grads = tape.gradient(loss, self.trainable_weights)
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
psnr = tf.reduce_mean(tf.image.psnr(high_res_pred, high_res_gt, 1.0))
self.loss_tracker.update_state(loss)
self.psnr_tracker.update_state(psnr)
return {
"loss": self.loss_tracker.result(),
"psnr": self.psnr_tracker.result(),
}
細かい書き方はこれまでのカスタム訓練と同じです。異なる点がいくつかあります。
1つ目は、オプティマイザを外部から与えず、モデルのもの(インスタンス変数)を使うこと。オプティマイザはモデル内で一切記述していませんが、compile
が呼ばれたときに自動で登録されます。
2つ目は、train_stepを@tf.function
のデコレーターで囲まなくてよいこと。このデコレーターは訓練をグラフモードで高速に実行するために必要ですが、そもそもmodelのfitがデフォルトでグラフモードとして実行されるため、ステップ単位の処理を書くだけでOKです。fit自体が@tf.function
で囲まれているようなものです。
metricsのプロパティを実装する
これはなくても動いてしまいますが、全てのトラッカーをリストで入れるように実装してください。後述のトラッカーのリセットの関係で必要になります。
@property
def metrics(self):
return [self.loss_tracker, self.psnr_tracker]
ログの結果をdictで返す
見ればそのとおりですが、Kerasのログに対応する値を書いていきます。
self.loss_tracker.update_state(loss)
self.psnr_tracker.update_state(psnr)
return {
"loss": self.loss_tracker.result(),
"psnr": self.psnr_tracker.result(),
}
出力はこうなります。
391/391 [==============================] - 3s 9ms/step - loss: 0.0559 - psnr: 23.7072 - val_loss: 0.0414 - val_psnr: 25.4800
Epoch 2/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0375 - psnr: 26.1364 - val_loss: 0.0365 - val_psnr: 26.3191
Epoch 3/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0354 - psnr: 26.5446 - val_loss: 0.0343 - val_psnr: 26.7604
Epoch 4/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0343 - psnr: 26.7636 - val_loss: 0.0338 - val_psnr: 26.8522
Epoch 5/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0339 - psnr: 26.8342 - val_loss: 0.0333 - val_psnr: 26.9524
Epoch 6/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0335 - psnr: 26.9146 - val_loss: 0.0334 - val_psnr: 26.9208
Epoch 7/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0332 - psnr: 26.9717 - val_loss: 0.0328 - val_psnr: 27.0393
Epoch 8/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0330 - psnr: 27.0053 - val_loss: 0.0326 - val_psnr: 27.0906
Epoch 9/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0328 - psnr: 27.0457 - val_loss: 0.0324 - val_psnr: 27.1292
Epoch 10/10
391/391 [==============================] - 3s 8ms/step - loss: 0.0328 - psnr: 27.0496 - val_loss: 0.0329 - val_psnr: 27.0406
fitでプログレスバーをいい感じに出してくれますし、訓練データに対してもPSNRを計算してくれるのがとてもいいですね。
validationの部分は「val_」と自動的にプレフィックスをつけてくれます。
Q1:評価関数のリセットはしているのですか?
ここで疑問なのは「評価関数(トラッカー)って訓練テスト共通で用意しているけど、訓練・テストの切り替わりや、エポックの切り替わりでバッチ単位のログをリセットしているのですか?」ということです。これ気になってTensorFlowのソースを読んでみました。
結論は、metricsのプロパティにトラッカーが登録されていれば、リセットの処理を書かなくても自動的にリセットされます。もともとtf.keras.model.Models
にはreset_metrics
という関数が組み込まれています(参考)。エポックの開始時、訓練・テストの切替時にreset_metrics
が自動的に呼び出されます。この関数の実装を見てみると、
for m in self.metrics:
m.reset_states()
とありました。ここで参照しているのはModelのmetricsのプロパティです。したがって、metricsにトラッカーを登録する必要があるのです。
もし登録しなくてもエラーは特に出ませんが、リセットが行われずに訓練・テスト間でログが使い回されるはずです。ここだけ注意が必要です。
Q2:self.trainable_weightsとself.model.trainable_weightsって何が違うんですか?
このケースでは1つしかモデルがないので特に変わりませんが、GANのように複数のモデルをfitで扱うケースだと意識する必要があります。
前者は入れ子にしているモデルの訓練可能な係数全て、後者は特定モデルの訓練可能な係数全てを表します。GANの場合使い分けが必要ですね。
Q3:ずっとfitにtf.functionがかかているとデバッグで困る。デバッグのときはEagerモードで動かしたい
コンパイル時にrun_eagerly=True
と指定しましょう。
model.compile(optimizer=tf.keras.optimizers.Adam(), run_eagerly=True)
これでEagerモードに簡単に切り替えられます。train_stepの中でprintをするとテンソルの値が出てくるのでデバッグに最適です。
まとめ
バッチ単位の訓練のカスタマイズと、fitのAPIは両立できるということが示せました。これはかなり便利な書き方だと思います。ぜひ活用してみてください。
Discussion