JAX入門~高速なNumPyとして使いこなすためのチュートリアル~
TensorFlow Advent Calendar 2020 10日目の記事です。空いてたので当日飛び入りで参加しました。
この記事では、TensorFlowの関連ライブラリである「JAX」について初歩的な使い方、ハマりどころ、GPU・TPUでの使い方や、画像処理への応用について解説します。
JAXとは
Google製のライブラリで、AutogradとXLAからなる、機械学習のための数値計算ライブラリ。簡単に言うと「自動微分に特化した、GPUやTPUに対応した高速なNumPy」。NumPyとほとんど同じ感覚で書くことができます。自動微分については解説が多いので、この記事では単なる高速なNumPyの部分を中心に書いていきます。
関連記事
- JAX Quickstart
- JAXで始めるディープラーニング
- JAX : Tutorials : JAX クイックスタート
- jaxのautogradをpytorchのautogradと比較、単回帰まで(速度比較追加)
GPU対応のNumPyという観点では、似たライブラリとしてPFN製のCuPyや、AnacondaがスポンサーとなっているNambdaもあります。
配列の初期化
最初はCPUに限定して書きます。JAXの導入はとてもシンプルで、あたかもNumPyのように使うことができます。
import jax.numpy as jnp
# NumPyではnp.arange(25, dtype=np.float32).reshape(5, 5)
x = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
print(x)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[[ 0. 1. 2. 3. 4.]
[ 5. 6. 7. 8. 9.]
[10. 11. 12. 13. 14.]
[15. 16. 17. 18. 19.]
[20. 21. 22. 23. 24.]]
JAXでのNumPy関数
.block_until_ready()
NumPy関数はnp
をjnp
に書き換えるだけ。ただし、JAXでは非同期処理で計算されるため、計算の最後に.block_until_ready()
を追加します。
# NumPyではnp.dot(x, x.T)
x_gram = jnp.dot(x, x.T).block_until_ready()
print(x_gram)
[[ 30. 80. 130. 180. 230.]
[ 80. 255. 430. 605. 780.]
[ 130. 430. 730. 1030. 1330.]
[ 180. 605. 1030. 1455. 1880.]
[ 230. 780. 1330. 1880. 2430.]]
特に理由がなければ.block_until_ready()
はJAXの計算の最後のみ入れればOKです。
y = x + 1
x_gram = jnp.dot(x, y.T).block_until_ready() # 最後だけブロッキングを入れればOK
パフォーマンス・チューニング
このまま使ってもJAX本来の性能を引き出せないので、jit
でXLAコンパイルします。メソッドを@jit
とデコレーターで囲むか、jit
でメソッド全体をラップします。.block_until_ready()
はjit
の外側に出します。
@jit
デコレーターでfrom jax import jit
@jit
def static_jax_dot():
x = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
x_gram = jnp.dot(x, x.T)
return x_gram
static_jax_dot().block_until_ready()
DeviceArray([[ 30., 80., 130., 180., 230.],
[ 80., 255., 430., 605., 780.],
[ 130., 430., 730., 1030., 1330.],
[ 180., 605., 1030., 1455., 1880.],
[ 230., 780., 1330., 1880., 2430.]], dtype=float32)
メソッドをjitでラップする
メソッド全体をラップする書き方は次の通りです。関数をjitで囲んで呼び出すため、()が2回出てきます。Kerasのレイヤーの書き方に似ていますね。
def static_jax_dot_nojit():
x = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
x_gram = jnp.dot(x, x.T)
return x_gram
jit(static_jax_dot)().block_until_ready() # jitで関数をラップして呼び出す
ダメな例:.block_until_ready()をjitの内側に入れる
jitの内側に.block_until_ready()
を入れてはいけません。エラーになります。
# ダメな例
@jit
def static_jax_dot_badexample():
x = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
x_gram = jnp.dot(x, x.T)
return x_gram.block_until_ready()
static_jax_dot_badexample()
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-34-d4f187296547> in <module>()
7 return x_gram.block_until_ready()
8
----> 9 static_jax_dot_badexample()
13 frames
/usr/local/lib/python3.6/dist-packages/jax/core.py in __getattr__(self, name)
529
530 try:
--> 531 attr = getattr(self.aval, name)
532 except KeyError as err:
533 raise AttributeError(
AttributeError: 'ShapedArray' object has no attribute 'block_until_ready'
jitでラップした関数に引数を渡すときは注意
jitでラップした関数に、引数を渡す際には注意が必要です。先程のstatic_jax_dot()
内のx
サイズを引数で可変にするようなケースです。
ダメな例:何も考えずにラップされた関数に引数を渡す
何も考えずに引数を持った関数をjitで囲むと次のようなエラーになります。
# ダメな例
@jit
def variable_jax_dot_badexample(size):
x = jnp.arange(size**2, dtype=jnp.float32).reshape(size, size)
x_gram = jnp.dot(x, x.T)
return x_gram
variable_jax_dot_badexample(5)
---------------------------------------------------------------------------
ConcretizationTypeError Traceback (most recent call last)
<ipython-input-36-1aa82304f1fa> in <module>()
6 return x_gram
7
----> 8 variable_jax_dot_badexample(5)
15 frames
/usr/local/lib/python3.6/dist-packages/jax/core.py in raise_concretization_error(val, context)
881 "See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
882 f"Encountered tracer value: {val}")
--> 883 raise ConcretizationTypeError(msg)
884
885
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
It arose in jax.numpy.arange argument `stop`.
While tracing the function variable_jax_dot_badexample at <ipython-input-36-1aa82304f1fa>:2, this concrete value was not available in Python because it depends on the value of the arguments to variable_jax_dot_badexample at <ipython-input-36-1aa82304f1fa>:2 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
「jnp.arange(...)
の引数でトレース値を検出しました。これは非常に計算コストが高いからやめてくれ」という趣旨です。トレーシングについてのエラーはTensorFlowで見たことある方もいると思います。同じ理由です。詳細は下記のドキュメントにあります。
良い例1:static_argnumsの引数を使ってjitでラップ
jitでラップするケース。static_argnums
というキーワード引数を指定します。
# 良い例1
def variable_jax_dot(size):
x = jnp.arange(size**2, dtype=jnp.float32).reshape(size, size)
x_gram = jnp.dot(x, x.T)
return x_gram
jit(variable_jax_dot, static_argnums=(0,))(5).block_until_ready()
キーワード引数static_argnums
の意味は、ラップする関数(variable_jax_dot
)の引数の何番目に固定値の引数があるかということです。size
はインデックス0の引数なので、ここでは(0, )
としています。引数が複数の場合次のようになります。
def multiple_inputs(x, size, axis):
y = jnp.arange(size**2, dtype=jnp.float32).reshape(size, size)
z = jnp.dot(jnp.sin(x), y)
return jnp.stack([y, z], axis=axis)
x = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
jit(multiple_inputs, static_argnums=(1, 2))(x, x.shape[0], 0)
DeviceArray([[[ 0. , 1. , 2. , 3. ,
4. ],
[ 5. , 6. , 7. , 8. ,
9. ],
[ 10. , 11. , 12. , 13. ,
14. ],
[ 15. , 16. , 17. , 18. ,
19. ],
[ 20. , 21. , 22. , 23. ,
24. ]],
[[ 0.28107762, 1.4161646 , 2.55125 , 3.686337 ,
4.821422 ],
[ 28.255533 , 29.075655 , 29.895779 , 30.715904 ,
31.536028 ],
[ 15.748972 , 15.079163 , 14.409353 , 13.739543 ,
13.069733 ],
[-19.320755 , -20.52088 , -21.721 , -22.921124 ,
-24.121246 ],
[-26.71011 , -26.721159 , -26.732206 , -26.743258 ,
-26.754307 ]]], dtype=float32)
良い例2:partialでjitをデコレートする
デコレーターの場合は、partial
というデコレーターを使います。
参考:jit decorator can't accept arguments #184
# 良い例2
from jax import partial
@partial(jit, static_argnums=(0,))
def variable_jax_dot_deco(size):
x = jnp.arange(size**2, dtype=jnp.float32).reshape(size, size)
x_gram = jnp.dot(x, x.T)
return x_gram
variable_jax_dot_deco(5).block_until_ready()
デコレーター@partial
でjit
関数をラップし、static_argnums
を指定するのがポイント。指定方法は先程と同じ。
パフォーマンス比較~CPU~
CPU環境で、
- ただのNumPy
- XLAコンパイルしないJAX
- XLAコンパイルしたJAX
を配列のサイズを変えて比較してみます。環境はColabのCPUインスタンスを用いました。ドット積はサイズを大きくすると計算が重かったので、要素間の計算で試してみます。
- NumPy v1.18.5
- JAX v0.2.6
import numpy as np
import jax.numpy as jnp
from jax import jit, partial
# (size, size)の行列を作ってMod計算
@partial(jit, static_argnums=(0,))
def jax_jit_mod(size):
x = jnp.arange(size, dtype=jnp.int32)
mat = x[None, :] * x[:, None] # (size, size)
return mat % 256
def jax_nojit_mod(size):
x = jnp.arange(size, dtype=jnp.int32)
mat = x[None, :] * x[:, None]
return mat % 256
def numpy_mod(size):
x = np.arange(size, dtype=np.int32)
mat = x[None, :] * x[:, None]
return mat % 256
for i in range(4):
size = 10**(i+1)
repeat = 10**(4-i)
print("size =", size, "repeat =", repeat)
%timeit -n {repeat} numpy_mod(size)
%timeit -n {repeat} jax_nojit_mod(size).block_until_ready() # jitなしJAX
%timeit -n {repeat} jax_jit_mod(size).block_until_ready() # jitありJAX
結果は次のようになりました。
サイズ | 反復回数 | NumPy | jitなしJAX | jitありJAX | NumPy÷jitあり |
---|---|---|---|---|---|
10 | 10000 | 6.07 µs | 2.25 ms | 63 µs | 0.10 |
100 | 1000 | 67.9 µs | 2.87 ms | 138 µs | 0.49 |
1000 | 100 | 6.33 ms | 9.41 ms | 781 µs | 8.10 |
10000 | 10 | 517 ms | 606 ms | 48 ms | 10.77 |
配列のサイズが100まではNumPyが高速でしたが、1000以降は「jitありJAX」が圧勝しました。このケースでは「jitなしJAX」を使う意味がありませんでした。「NumPy÷jitあり」はNumPyの処理時間をjitありJAXの処理時間で割ったもので、この値が大きいほどJAXが有利です。巨大な配列に対して高速に計算したいのならJAXがとてもおすすめということになります。画像や動画処理で効いてきそうですね。
パフォーマンス比較~GPU~
JAXはGPUでも使えます。GPU環境では特にコードを追加する必要がなく、デフォルトでGPUを使ってくれます。
同じ比較をしてみました。
GPUでint32の場合
- 環境:ColabのGPU、ランタイム割当はTesla P-100
サイズ | 反復回数 | NumPy | jitありJAX | NumPy÷jitあり |
---|---|---|---|---|
10 | 10000 | 11.6 µs | 114 µs | 0.10 |
100 | 1000 | 69 µs | 109 µs | 0.63 |
1000 | 100 | 5.62 ms | 118 µs | 47.63 |
10000 | 10 | 486 ms | 738 µs | 658.54 |
傾向はCPUと変わりませんが、サイズが1万のときは桁違いに速くNumPyの658倍の速度を叩き出しました。サイズが小さいケースで結果が安定しないのは、メモリコピーなど計算以外のボトルネックがあるからだと思われます。ディープラーニングでも同じことは起こります。
ディープラーニングのフレームワークが浮動小数点数中心なせいか、JAXは特に整数計算が強いような印象があります。float32でも確かめてみます。
GPUでfloat32の場合
サイズ | 反復回数 | NumPy | jitありJAX | NumPy÷jitあり |
---|---|---|---|---|
10 | 10000 | 7.21 µs | 112 µs | 0.06 |
100 | 1000 | 115 µs | 119 µs | 0.97 |
1000 | 100 | 10.1 ms | 179 µs | 56.42 |
10000 | 10 | 939 ms | 5.12 ms | 183.40 |
float32でも強いことは強いです。サイズ1万で183倍で、int32の658倍ほどは良くならなかったです。GPUのスペックによっても変わるでしょう。
デバイスを明示して実行する
これは実験的な機能で変更される可能性がありますが、JITはデバイスを明示して計算可能です。GPUが利用可能ならデフォルトでGPUを使いますが、CPUで実行することもできます。
import jax
def dot_function():
x = jnp.arange(1000**2, dtype=jnp.float32).reshape(1000, 1000)
return jnp.dot(x, x.T)
# devicesは実験的機能で変更される可能性がある
%timeit -n 100 jit(dot_function, device=jax.devices("cpu")[0])().block_until_ready()
# デフォルト。deviceを指定しない場合と一緒
%timeit -n 100 jit(dot_function, device=jax.devices("gpu")[0])().block_until_ready()
# 100 loops, best of 3: 33.4 ms per loop
# 100 loops, best of 3: 618 µs per loop
CPUを明示した場合は明らかに遅くなっているのが確認できます。
- https://jax.readthedocs.io/en/latest/jax.html#jax.jit
- https://jax.readthedocs.io/en/latest/jax.html#jax.devices
Colab TPUでのJAX
JAXはTPUでも利用可能です。Colab TPUのでは、以下のコードを事前に追加する必要があります。
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
resp = requests.post(url)
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
このコードは、JAX Cloud TPU Previewにある、Lorentz ODE Solverのコードからコピペしたものです。
TPUのドライバーが少し古かったり「ん?」と思うことはありますが、一応これで動きます。今後の仕様変更によっては変わる可能性が大いにあります。
パフォーマンス比較~TPU~
サイズ | 反復回数 | NumPy | jitありJAX | NumPy÷jitあり |
---|---|---|---|---|
10 | 10000 | 5.21 µs | 1.83 ms | 0.00 |
100 | 1000 | 97.2 µs | 1.76 ms | 0.06 |
1000 | 100 | 9.1 ms | 1.89 ms | 4.81 |
10000 | 10 | 780 ms | 4.39 ms | 177.68 |
int32型での比較です。GPUに比べて少し立ち上がりが遅いように見えますが、TPUの場合はコンパイル等最初のボトルネックが大きかったりするのでこれは仕方ないでしょう。
ただし、これはおそらくTPUの1コア分の計算でTPU本来の性能ではないと思われます。コア単位で分散計算させればもっと速度は出るはずです。詳細は省略しますが、pmapを活用すると複数のTPUコアに分散させられるとのことです。詳しくは以下の資料を見てください。
NumPy配列と併用する際は、device_putを使うべし
OpenCVで読み込んだ画像など、NumPy配列とJAXを併用するときは、device_put
でNumPy配列をデバイスに転送しておくと高速に動作します。
一例として、銀河の画像を用います。これはWikipediaにあるハッブル宇宙望遠鏡が撮影した銀河の画像ですが(上の画像は縮小して掲載しています)、横解像度が6637px、縦解像度が3787pxとかなり巨大な画像です。この巨大な画像を処理する際に高速なJAXが活躍するというわけです。まずファイルをダウンロードします(元画像は14.4MBあります)
!wget https://upload.wikimedia.org/wikipedia/commons/5/52/Hubble2005-01-barred-spiral-galaxy-NGC1300.jpg
これをOpenCVで読み込み、上から紫色のレイヤーをソフトライトでブレンドします。ソフトライトがなにかはここでは説明しませんが、Photoshopにあるような画像のブレンドモードの一種です。
import cv2
from jax import device_put
@jit
def galaxy():
original_bgr = cv2.imread("Hubble2005-01-barred-spiral-galaxy-NGC1300.jpg")
original_bgr = device_put(original_bgr)
original = original_bgr[:,:,::-1].astype(jnp.float32) / 255.0
# 紫のブレンドを作る
blend = jnp.ones(original.shape[:-1], dtype=jnp.float32)[..., None]
blend = blend * (jnp.array([235, 86, 230], dtype=jnp.float32).reshape(1, 1, -1) / 255.0)
# ソフトライトで合成
a = 2*original*blend + original**2*(1-2*blend)
b = 2*original*(1-blend) + jnp.sqrt(original)*(2*blend-1)
out = (blend<0.5)*a + (blend>=0.5)*b
# uint8に戻す
out = (out*255.0).astype(jnp.uint8)
return out
%timeit -n 100 galaxy().block_until_ready()
2行目のoriginal_bgr = device_put(original_bgr)
というのがポイント。PyTorchで配列をx = x.to(...)
とGPUに送るのに似ていますね。これはCPUでもあったほうがよくて、100回ループさせた結果、device_putなしが90.9msだったのに対し、device_putありが50msになりました。
なお、これをGPUに乗せると2.11msまで短縮されます。コンパイルいるとはいえ、6637×3787の画像処理が数msでできるというのは素晴らしいですね。
綺麗な紫色の銀河ができました! 結果のJAX配列をnp.array
としOpenCVなどに渡せば、画像を保存できます。
Colab Notebook
この記事で使ったColab Notebookはこちらです。
- CPU https://colab.research.google.com/drive/12YUh83fiueCcGdd3K0iO53HnjES3lf3T?usp=sharing
- GPU https://colab.research.google.com/drive/1pwtndzZ_EPBZQBf5t5mDIydcqsLjnbK_?usp=sharing
- TPU https://colab.research.google.com/drive/1WKX0XTmMspzhgSDErbjRMlVda4JFWIbn?usp=sharing
宣伝
技術書典10で新刊出します。NumPy関数だけで画像や動画を処理する(フォトショのような画像編集ソフトと同じ処理をする)本です。ここに出てきた「ソフトライトの合成」などをまさに扱っている本です(JAXは入ってないです)。実践演習形式で221問収録予定です。こちらもお楽しみに。
- Booth:https://koshian2.booth.pm/items/2462894
- 技術書典オンラインマーケット:https://techbookfest.org/product/5547509835366400
- その他まとめ:https://github.com/koshian2/numpy_book
- ISBN:978-4-910088-26-6
※Amazonでの物理書籍の取扱も調整中です
Discussion