poetry を使った jax 0.4.1 のインストール
poetry を使って jax をインストールしようとしてハマったのでメモ。
結論
以下のような pyproject.toml を書いて poetry install。
poetry==1.3.1, poetry-core==1.4.0 で動作を確認。
[[tool.poetry.source]] name = "jax" url = "https://storage.googleapis.com/jax-releases/jax_releases.html" default = false secondary = false [[tool.poetry.source]] name = "jax_cuda" url = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" default = false secondary = false [tool.poetry.dependencies] python = "^3.8.2" jaxlib = {version = "0.4.1", source = "jax_cuda"} jax = {version = "0.4.1", source = "jax"}
これで jaxlib==0.4.1+cuda11.cudnn86 と jax==0.4.1 が入る。cuda と cudnn のバージョンのところは環境依存かもしれない。 cuda 用の source とそうでない source を2つ書くのがポイント。 ただし、公式のインストール手順 https://github.com/google/jax#pip-installation-gpu-cuda には後者の URL は記載されておらず今後動かなくなる恐れがあるかもしれない。
cuda 用の jaxlib は Google が提供している URL から持ってくる必要があるが jax の方は PyPI にあるやつで動くはずなのだがなぜか poetry add jax だけだと
Unable to find installation candidates for jax (0.4.1)
と言われてしまいインストールできなかった。