Open3

jax.numpy でミニバッチ処理を行うときのテク

odanodan

jax.numpy はジャグ配列を扱えない
正方形な行列しか jnp.array 化できないし、jax.jit での最適化、 jax.vmap での並列化を行えない
なので、ジャグ配列は末尾に -1 を padding するなどして正方形な行列にして、その -1 を無視する処理を実装する必要がある
ここでは無視するテクを書いていく