JAX入門~高速なNumPyとして使いこなすためのチュートリアル~

18 min読了の目安(約16800字TECH技術記事

TensorFlow Advent Calendar 2020 10日目の記事です。空いてたので当日飛び入りで参加しました。

この記事では、TensorFlowの関連ライブラリである「JAX」について初歩的な使い方、ハマりどころ、GPU・TPUでの使い方や、画像処理への応用について解説します。

JAXとは

https://github.com/google/jax

Google製のライブラリで、AutogradとXLAからなる、機械学習のための数値計算ライブラリ。簡単に言うと「自動微分に特化した、GPUやTPUに対応した高速なNumPy」。NumPyとほとんど同じ感覚で書くことができます。自動微分については解説が多いので、この記事では単なる高速なNumPyの部分を中心に書いていきます。

関連記事

GPU対応のNumPyという観点では、似たライブラリとしてPFN製のCuPyや、AnacondaがスポンサーとなっているNambdaもあります。

配列の初期化

最初はCPUに限定して書きます。JAXの導入はとてもシンプルで、あたかもNumPyのように使うことができます。

import jax.numpy as jnp

# NumPyではnp.arange(25, dtype=np.float32).reshape(5, 5)
x = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
print(x)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[[ 0.  1.  2.  3.  4.]
 [ 5.  6.  7.  8.  9.]
 [10. 11. 12. 13. 14.]
 [15. 16. 17. 18. 19.]
 [20. 21. 22. 23. 24.]]

JAXでのNumPy関数

.block_until_ready()

NumPy関数はnpjnpに書き換えるだけ。ただし、JAXでは非同期処理で計算されるため、計算の最後に.block_until_ready()を追加します。

# NumPyではnp.dot(x, x.T)
x_gram = jnp.dot(x, x.T).block_until_ready()
print(x_gram)
[[  30.   80.  130.  180.  230.]
 [  80.  255.  430.  605.  780.]
 [ 130.  430.  730. 1030. 1330.]
 [ 180.  605. 1030. 1455. 1880.]
 [ 230.  780. 1330. 1880. 2430.]]

特に理由がなければ.block_until_ready()はJAXの計算の最後のみ入れればOKです。

y = x + 1
x_gram = jnp.dot(x, y.T).block_until_ready() # 最後だけブロッキングを入れればOK

パフォーマンス・チューニング

このまま使ってもJAX本来の性能を引き出せないので、jitでXLAコンパイルします。メソッドを@jitデコレーターで囲むか、jitでメソッド全体をラップします。.block_until_ready()jitの外側に出します。

デコレーターで@jit

from jax import jit

@jit
def static_jax_dot():
    x = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
    x_gram = jnp.dot(x, x.T)
    return x_gram

static_jax_dot().block_until_ready()
DeviceArray([[  30.,   80.,  130.,  180.,  230.],
             [  80.,  255.,  430.,  605.,  780.],
             [ 130.,  430.,  730., 1030., 1330.],
             [ 180.,  605., 1030., 1455., 1880.],
             [ 230.,  780., 1330., 1880., 2430.]], dtype=float32)

メソッドをjitでラップする

メソッド全体をラップする書き方は次の通りです。関数をjitで囲んで呼び出すため、()が2回出てきます。Kerasのレイヤーの書き方に似ていますね。

def static_jax_dot_nojit():
    x = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
    x_gram = jnp.dot(x, x.T)
    return x_gram

jit(static_jax_dot)().block_until_ready() # jitで関数をラップして呼び出す

ダメな例:.block_until_ready()をjitの内側に入れる

jitの内側に.block_until_ready()を入れてはいけません。エラーになります。

# ダメな例

@jit
def static_jax_dot_badexample():
    x = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
    x_gram = jnp.dot(x, x.T)
    return x_gram.block_until_ready()

static_jax_dot_badexample()
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-34-d4f187296547> in <module>()
      7     return x_gram.block_until_ready()
      8 
----> 9 static_jax_dot_badexample()

13 frames
/usr/local/lib/python3.6/dist-packages/jax/core.py in __getattr__(self, name)
    529 
    530     try:
--> 531       attr = getattr(self.aval, name)
    532     except KeyError as err:
    533       raise AttributeError(

AttributeError: 'ShapedArray' object has no attribute 'block_until_ready'

jitでラップした関数に引数を渡すときは注意

jitでラップした関数に、引数を渡す際には注意が必要です。先程のstatic_jax_dot()内のxサイズを引数で可変にするようなケースです。

ダメな例:何も考えずにラップされた関数に引数を渡す

何も考えずに引数を持った関数をjitで囲むと次のようなエラーになります。

# ダメな例
@jit
def variable_jax_dot_badexample(size):
    x = jnp.arange(size**2, dtype=jnp.float32).reshape(size, size)
    x_gram = jnp.dot(x, x.T)
    return x_gram

variable_jax_dot_badexample(5)
---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
<ipython-input-36-1aa82304f1fa> in <module>()
      6     return x_gram
      7 
----> 8 variable_jax_dot_badexample(5)

15 frames
/usr/local/lib/python3.6/dist-packages/jax/core.py in raise_concretization_error(val, context)
    881          "See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
    882           f"Encountered tracer value: {val}")
--> 883   raise ConcretizationTypeError(msg)
    884 
    885 

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

It arose in jax.numpy.arange argument `stop`.

While tracing the function variable_jax_dot_badexample at <ipython-input-36-1aa82304f1fa>:2, this concrete value was not available in Python because it depends on the value of the arguments to variable_jax_dot_badexample at <ipython-input-36-1aa82304f1fa>:2 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>

jnp.arange(...)の引数でトレース値を検出しました。これは非常に計算コストが高いからやめてくれ」という趣旨です。トレーシングについてのエラーはTensorFlowで見たことある方もいると思います。同じ理由です。詳細は下記のドキュメントにあります。

https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error

良い例1:static_argnumsの引数を使ってjitでラップ

jitでラップするケース。static_argnumsというキーワード引数を指定します。

# 良い例1
def variable_jax_dot(size):
    x = jnp.arange(size**2, dtype=jnp.float32).reshape(size, size)
    x_gram = jnp.dot(x, x.T)
    return x_gram

jit(variable_jax_dot, static_argnums=(0,))(5).block_until_ready()

キーワード引数static_argnumsの意味は、ラップする関数(variable_jax_dot)の引数の何番目に固定値の引数があるかということです。sizeはインデックス0の引数なので、ここでは(0, )としています。引数が複数の場合次のようになります。

def multiple_inputs(x, size, axis):
    y = jnp.arange(size**2, dtype=jnp.float32).reshape(size, size)
    z = jnp.dot(jnp.sin(x), y)
    return jnp.stack([y, z], axis=axis)

x = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
jit(multiple_inputs, static_argnums=(1, 2))(x, x.shape[0], 0)
DeviceArray([[[  0.        ,   1.        ,   2.        ,   3.        ,
                 4.        ],
              [  5.        ,   6.        ,   7.        ,   8.        ,
                 9.        ],
              [ 10.        ,  11.        ,  12.        ,  13.        ,
                14.        ],
              [ 15.        ,  16.        ,  17.        ,  18.        ,
                19.        ],
              [ 20.        ,  21.        ,  22.        ,  23.        ,
                24.        ]],

             [[  0.28107762,   1.4161646 ,   2.55125   ,   3.686337  ,
                 4.821422  ],
              [ 28.255533  ,  29.075655  ,  29.895779  ,  30.715904  ,
                31.536028  ],
              [ 15.748972  ,  15.079163  ,  14.409353  ,  13.739543  ,
                13.069733  ],
              [-19.320755  , -20.52088   , -21.721     , -22.921124  ,
               -24.121246  ],
              [-26.71011   , -26.721159  , -26.732206  , -26.743258  ,
               -26.754307  ]]], dtype=float32)

良い例2:partialでjitをデコレートする

デコレーターの場合は、partialというデコレーターを使います。

参考:jit decorator can't accept arguments #184

# 良い例2
from jax import partial

@partial(jit, static_argnums=(0,))
def variable_jax_dot_deco(size):
    x = jnp.arange(size**2, dtype=jnp.float32).reshape(size, size)
    x_gram = jnp.dot(x, x.T)
    return x_gram

variable_jax_dot_deco(5).block_until_ready()

デコレーター@partialjit関数をラップし、static_argnumsを指定するのがポイント。指定方法は先程と同じ。

パフォーマンス比較~CPU~

CPU環境で、

  • ただのNumPy
  • XLAコンパイルしないJAX
  • XLAコンパイルしたJAX

を配列のサイズを変えて比較してみます。環境はColabのCPUインスタンスを用いました。ドット積はサイズを大きくすると計算が重かったので、要素間の計算で試してみます。

  • NumPy v1.18.5
  • JAX v0.2.6
import numpy as np
import jax.numpy as jnp
from jax import jit, partial

# (size, size)の行列を作ってMod計算
@partial(jit, static_argnums=(0,))
def jax_jit_mod(size):
    x = jnp.arange(size, dtype=jnp.int32)
    mat = x[None, :] * x[:, None] # (size, size)
    return mat % 256

def jax_nojit_mod(size):
    x = jnp.arange(size, dtype=jnp.int32)
    mat = x[None, :] * x[:, None]
    return mat % 256

def numpy_mod(size):
    x = np.arange(size, dtype=np.int32)
    mat = x[None, :] * x[:, None]
    return mat % 256

for i in range(4):
    size = 10**(i+1)
    repeat = 10**(4-i)
    print("size =", size, "repeat =", repeat)
    %timeit -n {repeat} numpy_mod(size)
    %timeit -n {repeat} jax_nojit_mod(size).block_until_ready() # jitなしJAX
    %timeit -n {repeat} jax_jit_mod(size).block_until_ready() # jitありJAX

結果は次のようになりました。

サイズ 反復回数 NumPy jitなしJAX jitありJAX NumPy÷jitあり
10 10000 6.07 µs 2.25 ms 63 µs 0.10
100 1000 67.9 µs 2.87 ms 138 µs 0.49
1000 100 6.33 ms 9.41 ms 781 µs 8.10
10000 10 517 ms 606 ms 48 ms 10.77

配列のサイズが100まではNumPyが高速でしたが、1000以降は「jitありJAX」が圧勝しました。このケースでは「jitなしJAX」を使う意味がありませんでした。「NumPy÷jitあり」はNumPyの処理時間をjitありJAXの処理時間で割ったもので、この値が大きいほどJAXが有利です。巨大な配列に対して高速に計算したいのならJAXがとてもおすすめということになります。画像や動画処理で効いてきそうですね。

パフォーマンス比較~GPU~

JAXはGPUでも使えます。GPU環境では特にコードを追加する必要がなく、デフォルトでGPUを使ってくれます

同じ比較をしてみました。

GPUでint32の場合

  • 環境:ColabのGPU、ランタイム割当はTesla P-100
サイズ 反復回数 NumPy jitありJAX NumPy÷jitあり
10 10000 11.6 µs 114 µs 0.10
100 1000 69 µs 109 µs 0.63
1000 100 5.62 ms 118 µs 47.63
10000 10 486 ms 738 µs 658.54

傾向はCPUと変わりませんが、サイズが1万のときは桁違いに速くNumPyの658倍の速度を叩き出しました。サイズが小さいケースで結果が安定しないのは、メモリコピーなど計算以外のボトルネックがあるからだと思われます。ディープラーニングでも同じことは起こります。

ディープラーニングのフレームワークが浮動小数点数中心なせいか、JAXは特に整数計算が強いような印象があります。float32でも確かめてみます。

GPUでfloat32の場合

サイズ 反復回数 NumPy jitありJAX NumPy÷jitあり
10 10000 7.21 µs 112 µs 0.06
100 1000 115 µs 119 µs 0.97
1000 100 10.1 ms 179 µs 56.42
10000 10 939 ms 5.12 ms 183.40

float32でも強いことは強いです。サイズ1万で183倍で、int32の658倍ほどは良くならなかったです。GPUのスペックによっても変わるでしょう。

デバイスを明示して実行する

これは実験的な機能で変更される可能性がありますが、JITはデバイスを明示して計算可能です。GPUが利用可能ならデフォルトでGPUを使いますが、CPUで実行することもできます。

import jax

def dot_function():
    x = jnp.arange(1000**2, dtype=jnp.float32).reshape(1000, 1000)
    return jnp.dot(x, x.T)

# devicesは実験的機能で変更される可能性がある
%timeit -n 100 jit(dot_function, device=jax.devices("cpu")[0])().block_until_ready()
# デフォルト。deviceを指定しない場合と一緒
%timeit -n 100 jit(dot_function, device=jax.devices("gpu")[0])().block_until_ready() 
# 100 loops, best of 3: 33.4 ms per loop
# 100 loops, best of 3: 618 µs per loop

CPUを明示した場合は明らかに遅くなっているのが確認できます。

Colab TPUでのJAX

JAXはTPUでも利用可能です。Colab TPUのでは、以下のコードを事前に追加する必要があります。

import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
    url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
    resp = requests.post(url)

from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

このコードは、JAX Cloud TPU Previewにある、Lorentz ODE Solverのコードからコピペしたものです。

TPUのドライバーが少し古かったり「ん?」と思うことはありますが、一応これで動きます。今後の仕様変更によっては変わる可能性が大いにあります。

パフォーマンス比較~TPU~

サイズ 反復回数 NumPy jitありJAX NumPy÷jitあり
10 10000 5.21 µs 1.83 ms 0.00
100 1000 97.2 µs 1.76 ms 0.06
1000 100 9.1 ms 1.89 ms 4.81
10000 10 780 ms 4.39 ms 177.68

int32型での比較です。GPUに比べて少し立ち上がりが遅いように見えますが、TPUの場合はコンパイル等最初のボトルネックが大きかったりするのでこれは仕方ないでしょう。

ただし、これはおそらくTPUの1コア分の計算でTPU本来の性能ではないと思われます。コア単位で分散計算させればもっと速度は出るはずです。詳細は省略しますが、pmapを活用すると複数のTPUコアに分散させられるとのことです。詳しくは以下の資料を見てください。

Pmap Cookbook

NumPy配列と併用する際は、device_putを使うべし

OpenCVで読み込んだ画像など、NumPy配列とJAXを併用するときは、device_putでNumPy配列をデバイスに転送しておくと高速に動作します。

一例として、銀河の画像を用います。これはWikipediaにあるハッブル宇宙望遠鏡が撮影した銀河の画像ですが(上の画像は縮小して掲載しています)、横解像度が6637px、縦解像度が3787pxとかなり巨大な画像です。この巨大な画像を処理する際に高速なJAXが活躍するというわけです。まずファイルをダウンロードします(元画像は14.4MBあります)

!wget https://upload.wikimedia.org/wikipedia/commons/5/52/Hubble2005-01-barred-spiral-galaxy-NGC1300.jpg

これをOpenCVで読み込み、上から紫色のレイヤーをソフトライトでブレンドします。ソフトライトがなにかはここでは説明しませんが、Photoshopにあるような画像のブレンドモードの一種です。

import cv2
from jax import device_put

@jit
def galaxy():
    original_bgr = cv2.imread("Hubble2005-01-barred-spiral-galaxy-NGC1300.jpg")
    original_bgr = device_put(original_bgr)
    original = original_bgr[:,:,::-1].astype(jnp.float32) / 255.0
    # 紫のブレンドを作る
    blend = jnp.ones(original.shape[:-1], dtype=jnp.float32)[..., None]
    blend = blend * (jnp.array([235, 86, 230], dtype=jnp.float32).reshape(1, 1, -1) / 255.0)
    # ソフトライトで合成
    a = 2*original*blend + original**2*(1-2*blend)
    b = 2*original*(1-blend) + jnp.sqrt(original)*(2*blend-1)
    out = (blend<0.5)*a + (blend>=0.5)*b
    # uint8に戻す
    out = (out*255.0).astype(jnp.uint8)
    return out

%timeit -n 100 galaxy().block_until_ready()

2行目のoriginal_bgr = device_put(original_bgr)というのがポイント。PyTorchで配列をx = x.to(...)とGPUに送るのに似ていますね。これはCPUでもあったほうがよくて、100回ループさせた結果、device_putなしが90.9msだったのに対し、device_putありが50msになりました。

なお、これをGPUに乗せると2.11msまで短縮されます。コンパイルいるとはいえ、6637×3787の画像処理が数msでできるというのは素晴らしいですね。

綺麗な紫色の銀河ができました! 結果のJAX配列をnp.arrayとしOpenCVなどに渡せば、画像を保存できます。

Colab Notebook

この記事で使ったColab Notebookはこちらです。

宣伝

技術書典10で新刊出します。NumPy関数だけで画像や動画を処理する(フォトショのような画像編集ソフトと同じ処理をする)本です。ここに出てきた「ソフトライトの合成」などをまさに扱っている本です(JAXは入ってないです)。実践演習形式で221問収録予定です。こちらもお楽しみに。

※Amazonでの物理書籍の取扱も調整中です