🌊

Traxでカスタム活性化関数を使う

2021/06/26に公開

Traxとは

Traxとは、Googleが公開しているJAXベースの(ただしバックエンドにはtensorflow-numpyも選択できます)Deep Learningライブラリの一つです。
さらに言うと、Tensor2Tensorの後継です。

このFashion MNISTの分類ノートブックを見ると雰囲気が掴めると思います。

Traxでカスタム活性化関数を使う

Traxの活性化関数は、layersディレクトリの下にあります。

例えばLeakyReluは

@assert_shape('...->...')  # The output and input shapes are the same.
def LeakyRelu(a=0.01):
  r"""Returns a ReLU-like layer with linear nonzero outputs for negative inputs.
  .. math::
      f(x) = \left\{ \begin{array}{cl}
          ax & \text{if}\ x \leq 0, \\
          x  & \text{otherwise}.
      \end{array} \right.
  Args:
    a: Slope of line for negative inputs.
  """
  return Fn('LeakyRelu', lambda x: jnp.where(x >= 0, x, a * x))

定義されています
なので、TanhShrinkが欲しければ

@assert_shape("...->...")
def TanhShrink():
    return Fn("TanhShrink", lambda x: x - jnp.tanh(x))

とやれば良いです。
実際いくつか実装してみました。

https://github.com/Catminusminus/trax-extra-activation

現在以下が実装されてます。PyTorchにあってTrax本体になく、実装が簡単そうなものプラスアルファです。

  • Celu
  • HardShrink
  • HardSwish
  • LogSigmoid
  • Mish
  • Relu6
  • Silu
  • SoftShrink
  • SoftSign
  • TanhExp
  • TanhShrink

Traxがインストールされていれば、pip install trax-extra-activationで使えます。

おわりに

というわけで、Traxでカスタム活性化関数を使う方法を書きました。
が、実はちょっと片手落ちで、複雑な活性化関数には触れられませんでした。
複雑な、というのは、learnableだったり、randomnessがあるものです。

上で出てくるFnというのは、PureLayerを返す関数です。
で、PureLayerというのは、その名にあるように、pure functionのためのLayerです。
なので、learnableであったり、randomnessがあるものには使えません。

これについてはまたそのうちに書こうと思います。

Discussion