🤖

jax/flaxのインストール時のトラブルメモ

2023/01/17に公開約1,300字

インストール

GPUを使う時はpipインストール時に指定しなければいけない

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

実際にはcudnnのバージョンに合わせてバージョン指定したほうがいい(後述)

pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

バージョンについて

cudnnバージョンは環境のcudnnのメジャーバージョンは合わせなければいけない
マイナーバージョンは環境のcudnnよりも古いjaxをインストールする

環境にインストールされているcudnnのバージョンはヘッダファイルを見る

$ cat /usr/include/cudnn_version.h

---省略---
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 5
#define CUDNN_PATCHLEVEL 0
---省略---

上の例の場合、jax[cuda11_cudnn82]をインストールしなければいけない

GPUのメモリ

デフォルトではGPUメモリの9割をプリアロケートする。
チュートリアルやサンプルコードが動かないとかの報告がある。

特に、Getting startedのコードはTFDSでMNISTをロードする際にTensorFlowがGPUメモリをアロケートしてしまうので、TensorFlow側も対処が必要

JAXの設定

下記のどちらかを.bashrcでexportするか、実行時に指定する

export XLA_PYTHON_CLIENT_PREALLOCATE=false # preallocateしない場合
export XLA_PYTHON_CLIENT_MEM_FRACTION=.XX # 80-85%でうまくいくらしい。デフォルトは90%

TensorFlowの設定

下記をコード内で実行しておく

tf.config.experimental.set_visible_devices([], "GPU")

参考

https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
https://github.com/google/jax/issues/8746
https://tech.yellowback.net/posts/jax-oom

Discussion

ログインするとコメントできます