Open2
JAXでバックテスト
jax.jitする際には、(numbaと違って)分岐を含んでいるとコンパイルされない
なので
jax.lax.condとかjax.lax.switchを使う必要がある。
register_pytree_node_classを使えばclassをjitコンパイルする事ができる。
ただ、jitコンパイルしたクラスのattributeはprintデバッグできないので、そういうときはjitをdisableする
jax.disable_jit
jax.jitする際には、(numbaと違って)分岐を含んでいるとコンパイルされない
なので
jax.lax.condとかjax.lax.switchを使う必要がある。
register_pytree_node_classを使えばclassをjitコンパイルする事ができる。
ただ、jitコンパイルしたクラスのattributeはprintデバッグできないので、そういうときはjitをdisableする
jax.disable_jit