
Speeding Up JAX with the Apple Silicon GPU

Kai Sugahara

A quick comparison of computation time on the CPU and on METAL (the GPU).
Test environment:
MacBook Air (2022), Apple M2, 16 GB memory, 256 GB SSD.

CPU.ipynb
import jax
import jax.numpy as jnp

jax.config.update("jax_platform_name", "cpu")
print(jax.devices())
# Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
# Metal device set to: Apple M2
# [CpuDevice(id=0)]

@jax.jit
def f(x):
    return x.dot(x)

row = 1000
X = jnp.arange(row**2, dtype=jnp.float32).reshape(row, row)

%%timeit
f(X)
# 4.49 ms ± 79.9 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
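There is a caveat about the `%%timeit` measurement. JAX dispatches work asynchronously, so on an accelerator such as the Metal GPU, `f(X)` can return before the matrix product has actually finished. A timing loop may then measure mostly dispatch overhead. The sketch below is a minimal, hypothetical helper (`time_call` is not part of JAX) that addresses this. It waits on any result that has a `block_until_ready()` method, which JAX arrays provide, and it makes warm-up calls first so that `jax.jit` compilation is excluded from the timing.

```python
import statistics
import time


def time_call(fn, *args, warmup=1, repeats=100):
    """Return (mean, stdev) of wall-clock seconds per call of fn(*args).

    Hypothetical helper: if the result has a block_until_ready() method
    (as a single JAX array does), it is called so the timing covers the
    computation itself, not just asynchronous dispatch.
    """
    if repeats < 2:
        raise ValueError("need at least 2 repeats to compute a stdev")

    def run():
        out = fn(*args)
        # Wait for pending asynchronous work when the result supports it.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        return out

    # Warm-up calls absorb one-time costs such as jax.jit compilation.
    for _ in range(warmup):
        run()

    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        run()
        samples.append(time.perf_counter() - start)
    return statistics.mean(samples), statistics.stdev(samples)
```

With `f` and `X` from the notebook above, you would call it as `time_call(f, X)`. Inside IPython, `%timeit f(X).block_until_ready()` gives the equivalent measurement without a helper.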
GPU.ipynb
import jax
import jax.numpy as jnp

jax.config.update("jax_platform_name", "METAL")
print(jax.devices())
# Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
# Metal device set to: Apple M2
# [METAL(id=0)]

@jax.jit
def f(x):
    return x.dot(x)

row = 1000
X = jnp.arange(row**2, dtype=jnp.float32).reshape(row, row)

%%timeit
f(X)
# 903 μs ± 1.24 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)

The GPU run takes about one fifth of the CPU time (903 μs vs. 4.49 ms). Impressive.
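The "about one fifth" figure can be checked directly from the two reported means. The sketch below also converts each time into an implied throughput, using the standard count of 2·n³ floating-point operations for a dense n × n matrix product. One caution applies to these numbers. They assume `%%timeit` captured the full computation. Because JAX dispatches asynchronously and the notebooks do not call `block_until_ready()`, the true times could be longer, in which case both the speedup and the throughput figures would be overstated.

```python
# Mean timings reported by %%timeit in the two notebooks, in seconds.
cpu_s = 4.49e-3
gpu_s = 903e-6

n = 1000          # X is an n x n float32 matrix, as in the notebooks
flops = 2 * n**3  # dense matmul: n^3 multiply-adds = 2 * n^3 FLOPs

speedup = cpu_s / gpu_s
cpu_gflops = flops / cpu_s / 1e9
gpu_gflops = flops / gpu_s / 1e9

print(f"speedup: {speedup:.2f}x")  # speedup: 4.97x
print(f"CPU: {cpu_gflops:.0f} GFLOP/s, GPU: {gpu_gflops:.0f} GFLOP/s")
# CPU: 445 GFLOP/s, GPU: 2215 GFLOP/s
```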