Open3
jax.numpy でミニバッチ処理を行うときのテク

jax.numpy はジャグ配列を扱えない
正方形な行列しか jnp.array
化できないし、jax.jit
での最適化、 jax.vmap
での並列化を行えない
なので、ジャグ配列は末尾に -1 を padding するなどして正方形な行列にして、その -1 を無視する処理を実装する必要がある
ここでは無視するテクを書いていく

numpy.average の weights を使用する
numpy.average の weights を指定すると加重平均を求めることができる
これを使用して -1 が対応する要素が0の weights を作れば良い

numpy.sum の where を使用する
where を指定すると sum するときに無視する要素を指定できる
import jax
import jax.numpy as jnp
xs = jnp.array(
[
[1, 2, -1],
[4, 5, 6],
]
)
def sum(x):
return jnp.sum(x, where=x != -1)
jax.vmap(sum)(xs)