🦔

続:Traxでカスタム活性化関数を使う

4 min read

はじめに

この記事は以下の記事の続きです。

https://zenn.dev/catminusminus/articles/cf0c81857810f3

前回は、Googleが公開しているDeep Learningライブラリの一つであるTraxで、独自の活性化関数を使う方法を書きました。
ですが、その時は純粋関数しか扱いませんでした。
今回は、もう少し複雑な挙動をする活性化関数を実装します。

learnableなケース:ParametricRelu

TraxではParametricReluが実装されています。
ならば独自で実装する意味がなさそうに見えますが、実はTraxのParametricReluは、

@assert_shape('...->...')  # The output and input shapes are the same.
def ParametricRelu(a=1.):
  r"""Returns a layer that computes a ReLU function with the given slope.
  .. math::
      f(x) = \left\{ \begin{array}{cl}
          0  & \text{if}\ x \leq 0, \\
          ax & \text{otherwise}.
      \end{array} \right.
  Args:
    a: Slope of line for positive inputs.
  """
  return Fn('ParametricRelu', lambda x: jnp.maximum(a * x, jnp.zeros_like(x)))

定義されています

もちろんこれはLearnableではないです。
ではLearnableなParametricReluはどう実装すれば良いでしょうか。

ここで参考になるのは、次です

class ThresholdedLinearUnit(base.Layer):
  """Thresholded Linear Unit, c.f. https://arxiv.org/pdf/1911.09737.pdf ."""

  def init_weights_and_state(self, input_signature):
    """Initializes this layer's single weight to zero."""
    del input_signature
    self.weights = jnp.zeros((), dtype=jnp.float32)

  def forward(self, inputs):
    """Executes this layer as part of a forward pass through the model.
    Args:
      inputs: Tensor.
    Returns:
      Tensor of same shape and dtype as the input.
    """
    threshold = self.weights
    return jnp.maximum(inputs, threshold)

つまり、普通にLayerとして実装すればよさそうです。
というわけで実装してみたのが以下です。

@assert_shape("...->...")
class ParametricRelu(Layer):
    def __init__(self, num_parameters: int = 1, a: float = 0.25):
        super().__init__()
        self._a = a
        self._num_parameters = num_parameters

    def init_weights_and_state(self, input_signature):
        del input_signature
        self.weights = self._a * jnp.ones((self._num_parameters,), dtype=jnp.float32)

    def forward(self, inputs):
        return jnp.where(inputs <= 0, self.weights * inputs, inputs)

ここでnum_parametersPyTorchのインターフェイスに合わせました。

大体見た通りですが、重要な点が一つあります。
それは、self._aself._num_parametersのように、アンダーバーが変数についていることです。
これは必須です。
アンダーバーをつけずにself.aなどとすると動きません。
それが許されるのは、self.weightsなどの特殊なものだけです。

訓練かどうかで挙動が変わるケース:RandomizedRelu

訓練時とそれ以外で挙動を変えたいケースがあります。
典型的にはDropoutです。
そのような場合、どうすれば良いでしょうか。

Dropoutは、forward関数内で、self._modeが"train"かどうかで分岐しています

  def forward(self, x):
    """Executes this layer as part of a forward pass through the model.
    Args:
      x: Tensor of activations.
    Returns:
      Tensor of same shape and dtype as the input.
    """
    if self._mode != 'train':
      return x

なるほど、と思ってRandomizedReluを実装してみます。

@assert_shape("...->...")
class RandomizedRelu(Layer):
    def __init__(self, lower=0.125, upper=0.3333333333333333, mode="train"):
        super().__init__()
        self._lower = lower
        self._upper = upper
        self._mode = mode

    def forward(self, inputs):
        if self._mode == "train":
            a = fastmath.random.uniform(
                self.rng, dtype=jnp.float32, minval=self._lower, maxval=self._upper
            )
        else:
            a = (self._lower + self._upper) / 2
        return jnp.where(inputs <= 0, a * inputs, inputs)

これで万事解決、として記事を終わらせたかったのですが、問題がありました。

このような実装の場合、これらを使う側で、modeの切り替えが必要です。
で、trainer_lib.py内ではmodeが自動的に切り替わっています

    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')

しかしながら、このtrainer_lib.pyにあるのは、古いAPIです。
新しいAPIはtraining.pyの方にあるのですが、こちらではmodeの切り替え箇所が見つかりませんでした。
近いところはあるんですが、記録するmetricsの切り替えにしか使っておらず、modelの挙動をTrainTaskとEvalTaskで自動で切り替えてはいなさそうでした。

まだ新しい方は"under development"とのことなので、そのうち実装されるのかもしれません。
単に私が見落としているだけかも知れませんので、その際はぜひコメントいただければと存じます。

おわりに

というわけで、Traxにおける少し複雑な活性化関数の実装方法について紹介しました。