🤖

今更ながらVAEってこんなに凄かったの?ってなった話

2021/06/27に公開

はじめに

VAE、変分オートエンコーダのVAEです。機械学習のネットワークの一つです。
これを読まれている方は、VAEについて多少以上は御存じであるという前提でお話します。

VAEとは、端的に言えば特徴を潜在変数を介した表現に起き換える手法です。

潜在変数というのは、正体不明だがその入力を決定づける何らかの変数、といった感じに理解していればOKだと思います。

表に顕在していない、その入力を決定づける何かを、まず0まわりの値を取る自然な乱数的なサムシングとして仮定し、オートエンコーダで絞ったときの最低限の特徴がその0まわりの乱数的なサムシングのみで成り立つように設計するということです。

これ以上の詳しい説明は他に譲ります。参考としてはこちらが有名かと。

Variational Autoencoder徹底解説

VAEがこんなことできるって知ってた?

VAEによる正規化

ところでこの画像、何かわかりますか。

実はこれ、VAEを使って元の画像の「余計な回転や歪みを取り除いて」再構成したものになります。

数字の1が3つありますが、これが特に顕著で、バラバラの傾きが整えられています。

生成コードは末尾に記載します。

念のため説明を加えると、上2段がオリジナルのMNIST画像で、下2段が再構成した各画像になります。

厳密に言えば、変形を除いた各数字に対する複数のテンプレートを持っており、それと一致するようなパターンを再構成していると考えられるので、微妙に形が変わってもいます。

でもこれ、「すごくないですか」? これを見るまで私はVAEを誤解していました。

VAEは確かに「何らかの潜在変数を抽出するが、事前にここまで明示的な制御はできない」と思っていました。

例えば、潜在変数を後から覗いたときに、1番目の変数が回転のようなものを表わしていると確認できた。なので、1番目の変数を-1から1に変更しながらアニメーションさせてみると確かに回転している。などといったことはできるものの、何番目の変数が何を表わしているかはランダムだ、と思っていたんです。

VAEの潜在変数を覗いた例としては、PFNインターン(当時)の方が書かれたこちらの記事が詳しいので、ご案内しておきます。

しかしこの画像を生成した手順は違いました。

実をいうと、この画像の生成コードは3年前に投稿されたQiita記事に書かれていたもののパクりです。

該当記事は、「オートエンコーダにアフィン変換を組み込む」です。

なぜこれができるのか

端的にいうと、この処理は4つの手順から成っています。

1,画素をすべて座標に変換し
2.座標に潜在変数から抽出した変形行列を掛け
3.「変形行列を掛けて再構築したら入力画像になる」と学習し
4.再構成時のオプションとして2の手順を抜くことで変形が除去される

という形です。

つまり、普通に再構成したら入力画像になるけど、それは何かの回転とかが掛かったものだよ、と騙しうちのような手法で嘘を与え、嘘を真実にするように回転の掛かってない元の状態を学習するということ……。

言葉にすると難しいですね。

何にせよ、結果としてこの画像のようなことが実際できる、というのが全てです。

VAEによる正規化

あるいはコードを見た方がわかりやすいかもしれないのでコードを貼ります。

コード

python3 tensorflow2.5 tensorflow_probability
Google Colab辺りに投げればそのまま動くはず。(2021/06/27時点)
地味にTPU自動判定に対応しているのでTPU環境があればTPU使います。

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt


class Encoder(tf.keras.layers.Layer):
    def __init__(self):
        super(Encoder, self).__init__()
        self.layer1 = tf.keras.layers.Dense(512, activation="relu")
        self.layer2 = tf.keras.layers.Dense(512, activation="relu")
        self.layer3m = tf.keras.layers.Dense(16)
        self.layer3v = tf.keras.layers.Dense(16, activation="softplus")

    def call(self, x, *args, **kwargs):
        x = self.layer1(x)
        x = self.layer2(x)
        m = self.layer3m(x)
        v = self.layer3v(x)
        return [m, v]


class Decoder(tf.keras.layers.Layer):
    def __init__(self, h, w):
        super(Decoder, self).__init__()
        self.layer1 = tf.keras.layers.Dense(512, activation="relu")
        self.layer2 = tf.keras.layers.Dense(512, activation="relu")
        self.layer3 = tf.keras.layers.Dense(1, activation="sigmoid")
        ys = np.linspace(-1, 1, h)
        xs = np.linspace(-1, 1, w)
        uv = np.zeros((h, w, 2))
        for iy in range(0, h):
            for ix in range(0, w):
                uv[iy, ix] = np.array([ys[iy], xs[ix]])
        self.uv = tf.constant(uv, dtype=tf.float32)  # (h, w, 2)
        self.uv = tf.expand_dims(self.uv, axis=0)  # (1, h, w, 2)
        self.apply_transform = True

    def call(self, z, *args, **kwargs):
        p = z[:, -6:]  # transform params

        z = z[:, :-6]
        z = z[..., tf.newaxis, tf.newaxis]
        z = tf.transpose(z, (0, 2, 3, 1))
        uv = self.uv + tf.zeros_like(z[:, :, :, 0:1])  # broadcast
        if self.apply_transform:
            p_mat = tf.reshape(p, (-1, 2, 3))
            p_zeros = tf.zeros_like(p_mat[:, 0:1, :])
            p_mat = tf.concat([p_mat, p_zeros], axis=-2)
            p_mat = 0.1 * p_mat + tf.eye(3)
            uv1 = tf.ones_like(uv[:, :, :, 0:1])
            uv3 = tf.concat([uv, uv1], axis=-1)
            uv = tf.keras.backend.batch_dot(uv3, p_mat)[:, :, :, 0:2]

        z = tf.tile(z, (1, uv.shape[1], uv.shape[2], 1))  # broadcast
        uvz = tf.concat([uv, z], axis=-1)
        uvz = tf.reshape(uvz, (-1, uvz.shape[-1]))  # (batch * h * w, z + 2)

        x = self.layer1(uvz)
        x = self.layer2(x)
        x = self.layer3(x)
        x = tf.reshape(x, (-1, uv.shape[1], uv.shape[2], 1))
        return x


class Predictor(tf.keras.layers.Layer):
    def __init__(self):
        super(Predictor, self).__init__()
        self.layer1 = tf.keras.layers.Dense(512, activation="relu")
        self.layer2 = tf.keras.layers.Dense(10, activation="softmax")

    def call(self, x, *args, **kwargs):
        x = self.layer1(x)
        x = self.layer2(x)
        return x


class VAE(tf.keras.layers.Layer):
    def __init__(self, name="vae"):
        super(VAE, self).__init__(name=name)
        self.encoder = Encoder()
        self.decoder = Decoder(28, 28)
        self.predictor = Predictor()

    def call(self, x, *args, **kwargs):
        m, v = self.encoder(x)
        dist = tfp.distributions.Normal(m, v)
        z = dist.sample([1])
        z = z[0, :, :]
        y = self.decoder(z)
        mv = tf.stack([m, v], axis=-1)

        pred = self.predictor(z[:, :-1])  # exclude transform parameter
        return pred, y, mv


if __name__ == '__main__':
    tf.random.set_seed(1234)

    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
    except ValueError:
        strategy = tf.distribute.get_strategy()

    dataset = tf.keras.datasets.mnist
    (train_images, train_labels), (test_images, test_labels) = dataset.load_data()

    epochs = 50
    batch_size = 1000

    train_images = train_images / 255.0
    test_images = test_images / 255.0
    if len(train_images.shape) == 3:
        train_images = np.expand_dims(train_images, axis=-1)
        test_images = np.expand_dims(test_images, axis=-1)

    def prepare_model():
        def kl_d(y_ideal, y_pred):
            mu = y_pred[:, :, 0]
            var = y_pred[:, :, 1]
            dist_pred = tfp.distributions.Normal(mu, var + 1e-8)
            dist_standard = tfp.distributions.Normal(0, 1)
            loss = tfp.distributions.kl_divergence(dist_pred, dist_standard) / tf.cast(tf.size(mu), dtype=tf.float32)
            return loss

        vae = VAE()
        inputs = tf.keras.layers.Input(train_images[0].shape)
        x = tf.keras.layers.Flatten()(inputs)
        x = vae(x)
        model = tf.keras.Model(inputs=inputs, outputs=x)

        model.summary()
        model.compile(optimizer="adam",
                      loss=["sparse_categorical_crossentropy", "mse", kl_d],
                      metrics={"vae": ["accuracy"]},
                      loss_weights=[1e-3, 1, 5])  # heuristic
        return model, vae


    with strategy.scope():
        model, vae = prepare_model()

    model.fit(train_images, [train_labels, train_images, train_images], epochs=epochs, validation_split=0.02, batch_size=batch_size)
    model.evaluate(test_images, [test_labels, test_images, test_images])

    vae.decoder.apply_transform = False
    x = train_images[0:12]
    y = vae(tf.keras.layers.Flatten()(x))[1].numpy()

    fig, ax = plt.subplots(2, 1, tight_layout=True)
    x = np.transpose(x.reshape((-1, 2, 28, 28)), (1, 2, 0, 3)).reshape((56, -1))
    y = np.transpose(y.reshape((-1, 2, 28, 28)), (1, 2, 0, 3)).reshape((56, -1))

    ax[0].axis('off')
    ax[1].axis('off')

    ax[0].imshow(x)
    ax[1].imshow(y)
    plt.show()

若干のコード解説

デコーダのやっている事が若干特異かつ複雑なので説明します。
まずこの、apply_transformなところですが、pは単に潜在変数の一部を後ろから6個適当に切りとったものです。

これを3つ0足して、3x3の変形行列にreshapeして、uv(座標ベクトル)を3次元に拡張して掛けて2次元に取り直しています。

    p = z[:, -6:]  # transform params
    ...
    if self.apply_transform:
        p_mat = tf.reshape(p, (-1, 2, 3))
        p_zeros = tf.zeros_like(p_mat[:, 0:1, :])
        p_mat = tf.concat([p_mat, p_zeros], axis=-2)
        p_mat = 0.1 * p_mat + tf.eye(3)
        uv1 = tf.ones_like(uv[:, :, :, 0:1])
        uv3 = tf.concat([uv, uv1], axis=-1)
        uv = tf.keras.backend.batch_dot(uv3, p_mat)[:, :, :, 0:2]
        p_mat = 0.1 * p_mat + tf.eye(3)

ここで 0.1掛けたり単位行列足したりしてるのが気になるかもしれませんが、これは純粋に潜在変数をそのまま利用すると元の形からかけ離れた出力が生まれてしまうので、「何もしない(単位行列)」+「ちょっとした変形(0.1 * p_mat)」という意味で0.1を掛けています。

これは経験的な値です。

加えてもう一つわかりにくいかもしれないのが、ここですね。

        uvz = tf.concat([uv, z], axis=-1)
        uvz = tf.reshape(uvz, (-1, uvz.shape[-1]))  # (batch * h * w, z + 2)

        x = self.layer1(uvz)
        x = self.layer2(x)
        x = self.layer3(x)
        x = tf.reshape(x, (-1, uv.shape[1], uv.shape[2], 1))

ここでは、座標を「座標+潜在変数」というベクトルに置き換えて、デコーダの肝な3Denseに投げてます。入力は画像でなく「座標+潜在変数」ベクトルであり、出力は画像でなくスカラー値です。

つまり、(h, w, z + 2) という形をそのまま投げるかわりに、画素に相当する各「座標+潜在変数」ベクトルを、「バッチ次元に押し込めて」、(batch * h * w)個の「座標+潜在変数」ベクトルとして入力し、(batch * h * w)個の輝度値をスカラーとして取得し、(batch, h, w, 1) の画像に展開しなおしています。

この処理初めて見たとき、「これ考えた人天才か???」と本気で思いましたね……。一般的な処理なんですか?これ。少なくとも自分の発想には無かったし他で見たことないです。

この辺りの詳細はさっきの記事の人の下記の前提記事が詳しいです。

ニューラルネットで解像度に依存しない画像表現(MNIST編)

おわりに

ここまで、VAEすげぇという話でした。

というより、参考元の記事書いた人が凄いんじゃないかと疑ってます。

今回の記事は特に技術的解説とかでも無いですが、こういう事が出来るという発想を今まで持っておらず、純粋に驚いたので書いてみました。

もう一度これのすごいと思ったところを明確にしておくと、「変形処理を挟んで学習させてから変形を抜くと正規化できる」という形で「潜在変数を事前に完全制御している」ところでした。

例えばメガネを掛けている画像の潜在変数を抜き出して他のと合成して中間のを作る、とか、潜在変数弄ってみたらメガネが消えたのでこれはメガネの潜在変数、とかいう判断はよくある技術としてわかります。

しかし、事前に変形行列として扱ってほしいと設計するだけで、実際画像がどの程度回転等変形しているかのアノテーションすら与えなくても変形行列として完全に機能する、という点に驚いたのでした。

もしかして自分が知らないだけで一般的な技術なんですかね……?

StyleGanなんかは潜在変数をStyleとして扱う事でバリエーション豊かな画像を生成することができる点で似ていますが、あれってStyleを事前設計するようなことができるんでしょうか?

あるいは自分の無知を晒しただけかもしれませんが、誰かの何かのヒントになれば。

今回は以上です。

Discussion