🥚

numpy<=>mlx(swift)チートシート(work in progress)

に公開

ここに書いてあった

numpy mlx 備考
a[-1:] a[.stride(from:-1)]
a[:-3] a[.stride(to:-3)]
a[:] a[0 ...]
a[...] a[.ellipsis]
a[:,1] a[0 ...,:1]
a[None] a[.newAxis]
a[None] a.expandedDimensions(axes: [0])
a[...,None] a[.ellipsis,.newAxis]
a[...,None] a.expandedDimensions(axes: [-1])
np.ones_like(a) MLXArray.ones(like:a)
np.expand_dims(a,1) a.expandedDimensions(axes: [1])
a @ b a.matmul(b)
a.T a.transposed(0,1)
a.permutate(0,2,1) a.transposed(1,2)
_,indices = a.sorted() let indices = a.argSort()
a[b < 3] a[ (b .< 3).enumerated().compactMap { (i, flag) in flag ? Int32(i) : nil } ] 条件のmaskを直接与えると落ちるのでindicesに変換をかける
a.copy() MLXArray(data:a.asData(access: .copy))

Discussion