📚

TensorFlowでCIFAR-10を変分オートエンコーダー(VAE)する。

2023/05/20に公開

はじめに

TensorFlow のチュートリアルに、MNIST のデータセットを変分オートエンコーダー(VAE)でトレーニングするノートブックが公開されている。

https://www.tensorflow.org/tutorials/generative/cvae

次のステップに「CIFAR-10 などのほかのデータセットを使って VAE を実装してみるのもよいでしょう。」と記載がある。

この記事では、CIFAR-10 を用いて変分オートエンコーダーを実装してみる。コードの全体は以下に記載した。この記事では重要な部分について解説を行なう。

データセット

画像データを float32 に変換し、0~1 にスケールしてから用いる。

import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

x_train.shape, x_test.shape # ((50000, 32, 32, 3), (10000, 32, 32, 3))

モデルの構築

VAE は以下のようなモデルである。

Encoder: 画像  → (μ, σ)
Sample: (μ, σ)  →   z
Decoder: z → 画像

潜在変数 z を 128 次元のベクトルとして、以下のように実装した。

class VAE(Model):
    def __init__(self):
        super(VAE, self).__init__()
        n = 128
        self.optimizer = tf.keras.optimizers.Adam()
        self.encoder = Sequential([
            InputLayer(input_shape=(32, 32, 3)),
            Conv2D(64, (3, 3), padding='same', activation='relu', strides=2),
            BatchNormalization(),
            Conv2D(128, (3, 3), padding='same', activation='relu', strides=2),
            BatchNormalization(),
            Conv2D(256, (3, 3), padding='same', activation='relu', strides=2),
            BatchNormalization(),
            Conv2D(512, (3, 3), padding='same', activation='relu', strides=2),
            BatchNormalization(),
            Flatten(),
            Dense(2 * n),
        ], name='encoder')
        self.decoder = Sequential([
            InputLayer(input_shape=(n,)),
            Dense(1024, activation='relu'),
            BatchNormalization(),
            Dense(4 * 4 * 256, activation='relu'),
            BatchNormalization(),
            Reshape((4, 4, 256)),
            Conv2DTranspose(256, 3, activation='relu', strides=2, padding='same'),
            BatchNormalization(),
            Conv2DTranspose(256, 3, activation='relu', padding='same'),
            BatchNormalization(),
            Conv2DTranspose(128, 3, activation='relu', strides=2, padding='same'),
            BatchNormalization(),
            Conv2DTranspose(128, 3, activation='relu', padding='same'),
            BatchNormalization(),
            Conv2DTranspose(64, 3, activation='relu', strides=2, padding='same'),
            BatchNormalization(),
            Conv2DTranspose(64, 3, activation='relu', padding='same'),
            Conv2DTranspose(3, 3, activation='sigmoid', strides=1, padding='same'),
        ], name='decoder')
        self.loss_tracker=tf.keras.metrics.Mean(name='loss')

    def call(self, x):
        encoded = self.encoder(x)
        mu, logvar = tf.split(encoded, num_or_size_splits=2, axis=1)
        epsilon = tf.random.normal(shape=(tf.shape(mu)))
        sampled = mu + tf.exp(0.5 * logvar) * epsilon
        decoded = self.decoder(sampled)
        return decoded

    @property
    def metrics(self):
        return [self.loss_tracker]

    def train_step(self, x_batch):
        x, x_hat = x_batch
        with tf.GradientTape() as tape:
            encoded = self.encoder(x)
            mu, logvar = tf.split(encoded, num_or_size_splits=2, axis=1)
            epsilon = tf.random.normal(shape=(tf.shape(mu)))
            sampled = mu + tf.exp(0.5 * logvar) * epsilon
            decoded = self.decoder(sampled)
            loss = custom_loss(x_hat, decoded, mu, logvar)
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.loss_tracker.update_state(loss)
        return {'loss':self.loss_tracker.result()}

    def test_step(self, x_batch):
        x, x_hat = x_batch
        encoded = self.encoder(x)
        mu, logvar = tf.split(encoded, num_or_size_splits=2, axis=1)
        epsilon = tf.random.normal(shape=(tf.shape(mu)))
        sampled = mu + tf.exp(0.5 * logvar) * epsilon
        decoded = self.decoder(sampled)
        loss = custom_loss(x_hat, decoded, mu, logvar)
        self.loss_tracker.update_state(loss)
        return {'loss':self.loss_tracker.result()}

autoencoder = VAE()
autoencoder.build((None, 32, 32, 3))
autoencoder.summary()
autoencoder.encoder.summary()
autoencoder.decoder.summary()

(mu や logvar を loss 関数に渡す必要があるため、train_step を書く必要がある。なるべく tensorflow の api を使って実装を軽くする方針で実装した。)

ロス関数

オートエンコーダーでは Decoder の出力画像を入力画像が一致するように学習させる。
今回は入力画像と出力画像を交差エントロピーでロスを取る。さらにこれに加えて正則化項を加える。

L_{\text{reg}} = -{1\over 2} \sum \left(1-\mu^2-\sigma^2+\log \sigma^2\right)

この正則化項を加えることで z の潜在空間が正規分布になるように学習させることができる。ロス関数の実装は以下のよう。

def custom_loss(y_true, y_pred, mean, log_var):
    loss_rec = tf.reduce_mean(
        tf.reduce_sum(
            tf.keras.losses.binary_crossentropy(y_true, y_pred)
            , axis = (1,2)
        )
    )
    log_reg = tf.reduce_mean(
        tf.reduce_sum(
            -0.5 * (1 + 2 * log_var - tf.square(mean) - tf.exp(2 * log_var))
            , axis=1
        )
    )
    return loss_rec + log_reg

学習

autoencoder.compile()
autoencoder.fit(x_train, x_train, epochs=100, shuffle=True, batch_size=128, validation_data=(x_test, x_test))

結果

このエンコーダーに test データを通してみると、以下のようになった。


エンコーダーの入力


エンコーダーの出力

学習は進んでいるが画像生成まではできなかった。

おわりに

TensorFlow を用いて変分オートエンコーダーを実装し、CIFAR-10 のモデルを学習させた。結果として、上手に画像生成はできなかった。もっと変数を増やしたモデルを使って学習すると生成できるかもしれない。

上手に出来たこととしては、tensorflow の compile や fit の api を使うことで、複雑な学習を tensorflow のやり方に乗せて学習を回すことが出来た。今後の課題としては、loss 関数の数式の部分をしっかりと理解出来ていないので、そこは数式を追って理解しておきたい。

Discussion