🐍

google/jax を Poetry でインストールする

2022/10/03に公開

背景

google/jax が PEP 503 に対応していなくて困った
https://github.com/google/jax/issues/5410

Poetry 側にも issue がある

https://github.com/python-poetry/poetry/issues/5481

どうやら pip の --find-links 相当のものが 1.2.0 から実装されたらしい
https://github.com/python-poetry/poetry/issues/1391
https://github.com/python-poetry/poetry/pull/5517

このバージョン以降ならインストールできそう

インストール方法

$ poetry source add jax https://storage.googleapis.com/jax-releases/jax_releases.html
$ poetry add --source jax jaxlib
$ poetry add jax

これで 0.3.20+cuda11.cudnn82 の jaxlib がインストールされて、jax から cuda を認識できるようになった

仕組み

jax[cuda]jax[tpu]extras という仕組みを使ったインストール方法

https://peps.python.org/pep-0508/#extras

パッケージ側では setup.pyextras_require を書くことで指定できる

https://github.com/google/jax/blob/09720b9bcbf07c1318774033f1cb4f4751e37895/setup.py#L78-L98

jax の仕組みは、extras に応じて jaxlib の実装を切り替えることによってバックエンドを使い分けているっぽい

Discussion