Open4
Apple SiliconのGPUでJAXを高速化
を読んでいて,そういえばJAXはどうなっているのかな?と思い,調べたことをメモ.
結論から言うと
JAXはMETALに対応したプラグインが既に公開されていた.
jaxとjaxlibに加えて,jax-metalをインストールするだけでApple SiliconのGPUを活用可能.
jax, jaxlib, jax-metalのバージョンやMacOSの依存関係は
にまとまっている.CPUとMETAL(GPU)で計算時間を簡単に比較.
実行環境は
MacBook Air 2022, Apple M2, メモリ16GB, 256GB SSD.
CPU.ipynb
import jax
import jax.numpy as jnp
jax.config.update("jax_platform_name", "cpu")
print(jax.devices())
# Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
# Metal device set to: Apple M2
# [CpuDevice(id=0)]
@jax.jit
def f(x):
return x.dot(x)
row = 1000
X = jnp.arange(row**2, dtype=jnp.float32).reshape(row, row)
%%timeit
f(X)
# 4.49 ms ± 79.9 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
GPU.ipynb
import jax
import jax.numpy as jnp
jax.config.update("jax_platform_name", "METAL")
print(jax.devices())
# Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
# Metal device set to: Apple M2
# [METAL(id=0)]
@jax.jit
def f(x):
return x.dot(x)
row = 1000
X = jnp.arange(row**2, dtype=jnp.float32).reshape(row, row)
%%timeit
f(X)
# 903 μs ± 1.24 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
5分の1くらいになっている.すごい
アクティビティモニタを見たら,きちんとGPU使われていた.