続:Traxでカスタム活性化関数を使う
はじめに
この記事は以下の記事の続きです。
前回は、Googleが公開しているDeep Learningライブラリの一つであるTraxで、独自の活性化関数を使う方法を書きました。
ですが、その時は純粋関数しか扱いませんでした。
今回は、もう少し複雑な挙動をする活性化関数を実装します。
learnableなケース:ParametricRelu
TraxではParametricReluが実装されています。
ならば独自で実装する意味がなさそうに見えますが、実はTraxのParametricReluは、
@assert_shape('...->...') # The output and input shapes are the same.
def ParametricRelu(a=1.):
r"""Returns a layer that computes a ReLU function with the given slope.
.. math::
f(x) = \left\{ \begin{array}{cl}
0 & \text{if}\ x \leq 0, \\
ax & \text{otherwise}.
\end{array} \right.
Args:
a: Slope of line for positive inputs.
"""
return Fn('ParametricRelu', lambda x: jnp.maximum(a * x, jnp.zeros_like(x)))
と定義されています。
もちろんこれはLearnableではないです。
ではLearnableなParametricReluはどう実装すれば良いでしょうか。
ここで参考になるのは、次です。
class ThresholdedLinearUnit(base.Layer):
"""Thresholded Linear Unit, c.f. https://arxiv.org/pdf/1911.09737.pdf ."""
def init_weights_and_state(self, input_signature):
"""Initializes this layer's single weight to zero."""
del input_signature
self.weights = jnp.zeros((), dtype=jnp.float32)
def forward(self, inputs):
"""Executes this layer as part of a forward pass through the model.
Args:
inputs: Tensor.
Returns:
Tensor of same shape and dtype as the input.
"""
threshold = self.weights
return jnp.maximum(inputs, threshold)
つまり、普通にLayerとして実装すればよさそうです。
というわけで実装してみたのが以下です。
@assert_shape("...->...")
class ParametricRelu(Layer):
def __init__(self, num_parameters: int = 1, a: float = 0.25):
super().__init__()
self._a = a
self._num_parameters = num_parameters
def init_weights_and_state(self, input_signature):
del input_signature
self.weights = self._a * jnp.ones((self._num_parameters,), dtype=jnp.float32)
def forward(self, inputs):
return jnp.where(inputs <= 0, self.weights * inputs, inputs)
ここでnum_parameters
はPyTorchのインターフェイスに合わせました。
大体見た通りですが、重要な点が一つあります。
それは、self._a
やself._num_parameters
のように、アンダーバーが変数についていることです。
これは必須です。
アンダーバーをつけずにself.a
などとすると動きません。
それが許されるのは、self.weights
などの特殊なものだけです。
訓練かどうかで挙動が変わるケース:RandomizedRelu
訓練時とそれ以外で挙動を変えたいケースがあります。
典型的にはDropoutです。
そのような場合、どうすれば良いでしょうか。
Dropoutは、forward関数内で、self._mode
が"train"かどうかで分岐しています。
def forward(self, x):
"""Executes this layer as part of a forward pass through the model.
Args:
x: Tensor of activations.
Returns:
Tensor of same shape and dtype as the input.
"""
if self._mode != 'train':
return x
なるほど、と思ってRandomizedReluを実装してみます。
@assert_shape("...->...")
class RandomizedRelu(Layer):
def __init__(self, lower=0.125, upper=0.3333333333333333, mode="train"):
super().__init__()
self._lower = lower
self._upper = upper
self._mode = mode
def forward(self, inputs):
if self._mode == "train":
a = fastmath.random.uniform(
self.rng, dtype=jnp.float32, minval=self._lower, maxval=self._upper
)
else:
a = (self._lower + self._upper) / 2
return jnp.where(inputs <= 0, a * inputs, inputs)
これで万事解決、として記事を終わらせたかったのですが、問題がありました。
このような実装の場合、これらを使う側で、modeの切り替えが必要です。
で、trainer_lib.py内ではmodeが自動的に切り替わっています。
model_train = model(mode='train')
model_predict_eval = model(mode='eval')
しかしながら、このtrainer_lib.pyにあるのは、古いAPIです。
新しいAPIはtraining.pyの方にあるのですが、こちらではmodeの切り替え箇所が見つかりませんでした。
近いところはあるんですが、記録するmetricsの切り替えにしか使っておらず、modelの挙動をTrainTaskとEvalTaskで自動で切り替えてはいなさそうでした。
まだ新しい方は"under development"とのことなので、そのうち実装されるのかもしれません。
単に私が見落としているだけかも知れませんので、その際はぜひコメントいただければと存じます。
おわりに
というわけで、Traxにおける少し複雑な活性化関数の実装方法について紹介しました。
Discussion