🌊
Traxでカスタム活性化関数を使う
Traxとは
Traxとは、Googleが公開しているJAXベースの(ただしバックエンドにはtensorflow-numpyも選択できます)Deep Learningライブラリの一つです。
さらに言うと、Tensor2Tensorの後継です。
このFashion MNISTの分類ノートブックを見ると雰囲気が掴めると思います。
Traxでカスタム活性化関数を使う
Traxの活性化関数は、layersディレクトリの下にあります。
例えばLeakyReluは
@assert_shape('...->...') # The output and input shapes are the same.
def LeakyRelu(a=0.01):
r"""Returns a ReLU-like layer with linear nonzero outputs for negative inputs.
.. math::
f(x) = \left\{ \begin{array}{cl}
ax & \text{if}\ x \leq 0, \\
x & \text{otherwise}.
\end{array} \right.
Args:
a: Slope of line for negative inputs.
"""
return Fn('LeakyRelu', lambda x: jnp.where(x >= 0, x, a * x))
と定義されています。
なので、TanhShrinkが欲しければ
@assert_shape("...->...")
def TanhShrink():
return Fn("TanhShrink", lambda x: x - jnp.tanh(x))
とやれば良いです。
実際いくつか実装してみました。
現在以下が実装されてます。PyTorchにあってTrax本体になく、実装が簡単そうなものプラスアルファです。
- Celu
- HardShrink
- HardSwish
- LogSigmoid
- Mish
- Relu6
- Silu
- SoftShrink
- SoftSign
- TanhExp
- TanhShrink
Traxがインストールされていれば、pip install trax-extra-activation
で使えます。
おわりに
というわけで、Traxでカスタム活性化関数を使う方法を書きました。
が、実はちょっと片手落ちで、複雑な活性化関数には触れられませんでした。
複雑な、というのは、learnableだったり、randomnessがあるものです。
上で出てくるFn
というのは、PureLayer
を返す関数です。
で、PureLayer
というのは、その名にあるように、pure functionのためのLayerです。
なので、learnableであったり、randomnessがあるものには使えません。
これについてはまたそのうちに書こうと思います。
Discussion