📘

TensorFlow の使い方練習3:オプティマイザのカスタマイズ

2023/04/02に公開

はじめに

この記事は以下の記事の続きです。
https://zenn.dev/wsuzume/articles/cb1511666e2f99

今回はオプティマイザをカスタマイズします。大抵は Adam とか既に実装されてるやつを使っておけばいいのですが Riemann 多様体上の最適化、有名どころだと双曲埋め込みとかに使いたい場合はオプティマイザをカスタムする必要があります。そんな必要や欲求に駆られる変態は地球上にそんなにいないと思いますが。

今回作成したノートブックはこちら。

オプティマイザのカスタマイズ方法

公式ドキュメントの『Creating a custom optimizer』のところにちょっとだけ記述がある。逆に言えばこれしか情報がないのであとはソースコードから読み解くしかない。

https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Optimizer#creating_a_custom_optimizer_2

  1. tf.keras.optimizers.Optimizer クラスを継承します
  2. build メソッドを実装してオプティマイザの状態を保持する変数を定義します
  3. update_step メソッドでアップデートのロジックを記述します
  4. get_config メソッドですべてのハイパラを含むオプティマイザの情報をシリアライズします

というわけで割と簡単にできそう。あとはソース読めって感じな気がするので一番簡単なはずの tf.keras.optimizers.SGD[1] のソースを見に行こう。本質的な部分の build, update_step, get_config は 60 行くらいだろうか?

https://github.com/keras-team/keras/blob/v2.12.0/keras/optimizers/sgd.py#L26-L199

準備

いきなり複雑なオプティマイザを作るような図に乗った人間は地獄に堕ちる。どういう地獄かというと、大抵は収束しないしそれが

  1. 理論的に収束するかわからない
  2. 数式が合ってるかわからない
  3. 最適化対象のモデルの実装が合ってるかわからない
  4. その目的関数を最小化してうまくいくのかわからない
  5. オプティマイザのコードが合ってるかわからない
  6. ハイパーパラメータが合ってるかわからない
  7. 与えたデータが不適切かもわからない
  8. 浮動小数点演算に起因する数値誤差かもわからない
  9. たとえうまくいっているように見えても本当にうまくいっているのかわからない

という不明地獄である。

私は二度とその地獄を味わいたくないのでもっとも簡単な SGD から実装する。それもモデルはもっとも簡単で理論的に収束が保証されていて数式もコードも間違えようがない線形回帰モデルで二乗誤差を目的関数とし、絶対に回帰できる線形なデータを与えて浮動小数点演算に起因する数値誤差が出ないくらい小さいモデルにして TensorFlow の既存のオプティマイザでちゃんと収束することを確認した上で、だ。

再現性を取るためのシード固定

なぜか知らないが tf.random.set_seed だけでなく他のいろいろなシードも固定しないと結果が毎回変わりがち。

https://stackoverflow.com/questions/36288235/how-to-get-stable-results-with-tensorflow-setting-random-seed

import os
import random
import numpy as np
import tensorflow as tf

def reset_random_seeds(seed):
  os.environ['PYTHONHASHSEED'] = str(seed)
  random.seed(seed)
  np.random.seed(seed)
  tf.random.set_seed(seed)

reset_random_seeds(200)

擬似データの生成

X = np.random.normal(0, 1, (1000, 64))
w = np.random.normal(0, 10, (64, 2))

y = X.dot(w) + np.random.normal(0, 0.05, (1000, 2))

データセットのリピート設定

データセット中に 1000 個のデータがあるとして、batch_size を 32 とかに設定しておくと 32 ステップ目でデータを使い切ってエラーを吐いてしまうので、データセットを繰り返し使えるようにリピート設定をしないといけない。

dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.repeat().batch(32)

モデルの定義

線形回帰モデル。

inputs = keras.Input(shape=64)
outputs = layers.Dense(2)(inputs)

model = keras.Model(inputs=inputs, outputs=outputs)

model.compile(
    loss=keras.losses.MeanSquaredError(),
    optimizer=keras.optimizers.SGD(learning_rate=0.1),
    metrics=keras.metrics.MeanSquaredError()
)

学習と結果の確認

result = model.fit(dataset, epochs=20, steps_per_epoch=100)
from matplotlib import pyplot as plt

# Loss のプロット
plt.figure(figsize=(8, 6))
plt.title('Mean Squared Error')
plt.xlabel('step')
plt.ylabel('loss')
plt.plot(result.history['loss'])
plt.yscale('log')
plt.grid()
plt.show()

y_pred = model(X)

print(y)
print(y_pred)
[[101.69491404  31.64441866]
 [-64.71029639  25.06918281]
 [-85.26896897  49.34619852]
 ...
 [-97.16003223 -33.47706224]
 [-88.16147633  68.83922654]
 [-19.45862791 -98.42042841]]
tf.Tensor(
[[101.71035   31.658903]
 [-64.59336   25.111202]
 [-85.33183   49.378597]
 ...
 [-97.19416  -33.51119 ]
 [-88.12234   68.76358 ]
 [-19.401066 -98.43188 ]], shape=(1000, 2), dtype=float32)

ロスが減少していて出力がおおよそ一致しているので OK。

SGD のコードを解読する

tf.keras.optimizers.SGD はどうも引数や内部に momentum があるので Momentum SGD と一体化されている。デフォルトが momentum=0.0 なのでこれを指定しなければ単なる SGD となる。はず。

確認が必要なのは SGD クラスに実装されている __init__, build, update_step, get_config の4つである。順番は前後するが以下が解読結果である。

__init__

__init__ の部分は長い割に

  • 自分に特有のメンバは自身に追加する
  • それ以外は親クラスの __init__ に渡す

という処理をやってるだけである。

ここで使われている Optimizer._build_learning_rate() はソースを読みに行くと learning_rate を TensorFlow の Variable に変換して返してくれるようである。

https://github.com/keras-team/keras/blob/f9336cc5114b4a9429a242deb264b707379646b7/keras/optimizers/optimizer.py#L377-L399

get_config

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    self._learning_rate
                ),
                "momentum": self.momentum,
                "nesterov": self.nesterov,
            }
        )
        return config

親クラスの get_config メソッドで取得した config(おそらく辞書)に対して、自身が持つ特有のハイパーパラメータを追加している。config の型を確認しておこう。

optimizer = keras.optimizers.SGD()
print(type(optimizer.get_config()))
output
<class 'dict'>

やっぱ辞書っぽい。

update_step

    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        lr = tf.cast(self.learning_rate, variable.dtype)
        m = None
        var_key = self._var_key(variable)
        momentum = tf.cast(self.momentum, variable.dtype)
        m = self.momentums[self._index_dict[var_key]]

        # TODO(b/204321487): Add nesterov acceleration.
        if isinstance(gradient, tf.IndexedSlices):
            # Sparse gradients.
            add_value = tf.IndexedSlices(
                -gradient.values * lr, gradient.indices
            )
            if m is not None:
                m.assign(m * momentum)
                m.scatter_add(add_value)
                if self.nesterov:
                    variable.scatter_add(add_value)
                    variable.assign_add(m * momentum)
                else:
                    variable.assign_add(m)
            else:
                variable.scatter_add(add_value)
        else:
            # Dense gradients
            if m is not None:
                m.assign(-gradient * lr + m * momentum)
                if self.nesterov:
                    variable.assign_add(-gradient * lr + m * momentum)
                else:
                    variable.assign_add(m)
            else:
                variable.assign_add(-gradient * lr)

よくわからない部分もあるが勾配が疎行列か密行列かで場合分けされていて、mNone でなければモーメンタムをわちゃわちゃしていることがわかる。つまり疎行列への対応とモーメンタムを除いてしまえば update_step は以下のコードまで簡略化できる。

    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        lr = tf.cast(self.learning_rate, variable.dtype)
        variable.assign_add(-gradient * lr)

これで単純な SGD くらいは実装できるようになったと思う。

tf.IndexedSlices はスパースなテンソルに対して用いるメモリ効率のよい表現のようである。

build

一番のクセモノがこいつ。

    def build(self, var_list):
        """Initialize optimizer variables.

        SGD optimizer has one variable `momentums`, only set if `self.momentum`
        is not 0.

        Args:
          var_list: list of model variables to build SGD variables on.
        """
        super().build(var_list)
        if hasattr(self, "_built") and self._built:
            return
        self.momentums = []
        for var in var_list:
            self.momentums.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="m"
                )
            )
        self._built = True

これもやってることは難しくない。まず親クラスの build メソッドをそのまま var_list を引数に呼び出している。説明のところに「Initialize optimizer variables.」と書いてあるので、オプティマイザが内部的に保持する変数があるならここで初期化しておけということだろう。

        if hasattr(self, "_built") and self._built:
            return
	 
	...
	
        self._built = True

この部分は複数回 build が実行されたとしても ... の部分が1回だけしか実行されないようにするためのロックである。親クラスの build がロックの外側で実行されているのが若干気になるが、この事実から

  • 似たようなロックが親クラス内にもある
  • model.fit の度に build が呼び出されており、途中から計算を再開するにはモーメンタムが初期化されてしまっては困る

のどちらかまたは両方の可能性がある。あとで検証しよう。

var_list がどういった形式で何が与えられているのか、親クラスで実装されているであろう add_variable_from_reference がどういった挙動をするのかが不明なのでここも確認対象である。

似たようなロックが親クラス内にもある?

ある。_built フラグを True にしておくと親クラスの build もスキップされる。

https://github.com/keras-team/keras/blob/f9336cc5114b4a9429a242deb264b707379646b7/keras/optimizers/optimizer.py#L402-L414

一方で親クラス内には _built フラグを True にするようなコードは見当たらない。よって

  • 親クラスは何回 build しても大丈夫なようにできているはず
    • 子クラスの build 内では super().build() を実行しろと公式ドキュメントに書いてある一方、_built フラグについては触れられていないので、親クラスの build は仕様上繰り返し実行される可能性がある。
  • 一度ビルドしたら二度と build してほしくないときは _built フラグを設定してもよい

ということだろう。つまり子クラスで _built フラグを立てておけば事実上 build は最初に必要になったときに一度だけ呼び出される

build はいつ呼び出されるのか?

ソースコードを読みに行って

  • Model.compile():不確定
    • 関係ありそうな Model.compile_from_config() からは呼び出されている
  • Model.fit():確定
    • Optimizer.minimize() の中の Optimizer.apply_gradients() から呼び出されている
    • ただしどういう条件で呼び出しがスキップされるかは不明

というところまでは特定している。他のメソッドから呼ばれる可能性があるかどうかは定かではない。StackOverflow に質問を投稿したのでそのうち回答があるかもしれない。

https://stackoverflow.com/questions/75848619/when-and-where-is-tensorflows-custom-optimizer-build-method-called

あとで作る自作クラスの build メソッドの中に print を書いてタイミングを確かめてみると、model.compile のときは実行されず、model.fit のときに呼び出されているようだった。

var_list はどういった形式で何が与えられているのか?

モデルに含まれるすべての変数がタプルで渡されるようである。あとで作る自作クラスの build メソッドの中に print(type(var_list)) を書いて確かめてみるとよい。

add_variable_from_reference はどういった挙動をするのか?

https://github.com/keras-team/keras/blob/f9336cc5114b4a9429a242deb264b707379646b7/keras/optimizers/optimizer.py#L471-L519

呼び出し側のコード
        for var in var_list:
            self.momentums.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="m"
                )
            )
呼び出される側の主なコード
                else:
                    # We cannot always use `zeros_like`, because some cases
                    # the shape exists while values don't.
                    initial_value = tf.zeros(
                        model_variable.shape, dtype=model_variable.dtype
                    )
	
	...
        
	variable = tf.Variable(
            initial_value=initial_value,
            name=f"{variable_name}/{model_variable._shared_name}",
            dtype=model_variable.dtype,
            trainable=False,
        )
        self._variables.append(variable)

第一引数として与えた変数と同じ shape の 0 埋めされたテンソルをオプティマイザの _variables メンバに f"{variable_name}/{model_variable._shared_name}" という名前で追加するような挙動をする。

if model_variable.shape.rank is None: という部分があるが、rank はテンソルの shape が可変なときに None になるらしい。テンソルの形状が固定のときはこの部分は気にしなくていい。

SGD を作ろう

以上で tf.keras.optimizers.SGD の中で何が行われているかおおよそ把握できた。最小構成で構築し直すと以下のようになる。

class MySGD(tf.keras.optimizers.Optimizer):
    def __init__(self,
        learning_rate=0.01,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        jit_compile=True,
        name="SGD",
        **kwargs
    ):
        super().__init__(
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            jit_compile=jit_compile,
            **kwargs
        )
        self._learning_rate = self._build_learning_rate(learning_rate)
    
    def build(self, var_list):
        """Initialize optimizer variables.

        SGD optimizer has one variable `momentums`, only set if `self.momentum`
        is not 0.

        Args:
          var_list: list of model variables to build SGD variables on.
        """
        # var_list がどうなっているか確かめたければこの辺りに print を仕込むなどする
        super().build(var_list)
        if hasattr(self, "_built") and self._built:
            return
        self._built = True
    
    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        lr = tf.cast(self.learning_rate, variable.dtype)
        variable.assign_add(-gradient * lr)
    
    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    self._learning_rate
                ),
            }
        )
        return config

このように作成したオプティマイザは keras で提供されているオプティマイザと同じように使用できる。

inputs = keras.Input(shape=64)
outputs = layers.Dense(2)(inputs)

model = keras.Model(inputs=inputs, outputs=outputs)

model.compile(
    loss=keras.losses.MeanSquaredError(),
    optimizer=MySGD(learning_rate=0.1),
    metrics=keras.metrics.MeanSquaredError()
)
result = model.fit(dataset, epochs=20, steps_per_epoch=10)
from matplotlib import pyplot as plt

# Loss のプロット
plt.figure(figsize=(8, 6))
plt.title('Mean Squared Error')
plt.xlabel('step')
plt.ylabel('loss')
plt.plot(result.history['loss'])
plt.yscale('log')
plt.grid()
plt.show()

tf.keras.optimizers.SGD で使っていた部分だけ抜き出しているので結果はほぼ一致する。

おしまい

まだ分からない部分(_variables に保存した補助変数の中からどうやって狙ったやつを取得するのかとか)はあるが、基本は読み解けたのでモーメンタムの部分とか Adam とかのコードとか読めば補完できるだろう。というわけで今回はこれでおしまいです。

脚注
  1. 厳密には tf.keras.optimizers.experimental.SGD だが、tf.keras.optimizers.SGD として直接インポートもできる。 ↩︎

Discussion