🤖
jax/flaxのインストール時のトラブルメモ
インストール
GPUを使う時はpipインストール時に指定しなければいけない
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
実際にはcudnnのバージョンに合わせてバージョン指定したほうがいい(後述)
pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2023/9/11追記
公式のgithub見るとcudnnのバージョン指定不要になったっぽい?
pip install --upgrade pip
# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
バージョンについて
cudnnバージョンは環境のcudnnのメジャーバージョンは合わせなければいけない
マイナーバージョンは環境のcudnnよりも古いjaxをインストールする
環境にインストールされているcudnnのバージョンはヘッダファイルを見る
$ cat /usr/include/cudnn_version.h
---省略---
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 5
#define CUDNN_PATCHLEVEL 0
---省略---
上の例の場合、jax[cuda11_cudnn82]をインストールしなければいけない
GPUのメモリ
デフォルトではGPUメモリの9割をプリアロケートする。
チュートリアルやサンプルコードが動かないとかの報告がある。
特に、Getting startedのコードはTFDSでMNISTをロードする際にTensorFlowがGPUメモリをアロケートしてしまうので、TensorFlow側も対処が必要
JAXの設定
下記のどちらかを.bashrcでexportするか、実行時に指定する
export XLA_PYTHON_CLIENT_PREALLOCATE=false # preallocateしない場合
export XLA_PYTHON_CLIENT_MEM_FRACTION=.XX # 80-85%でうまくいくらしい。デフォルトは90%
TensorFlowの設定
下記をコード内で実行しておく
tf.config.experimental.set_visible_devices([], "GPU")
参考
Discussion