反省はしても後悔はしない

Vim とか備忘録とか。それと関数型言語勉強中

poetry を使った jax 0.4.1 のインストール

poetry を使って jax をインストールしようとしてハマったのでメモ。

結論

以下のような pyproject.toml を書いて poetry install。

poetry==1.3.1, poetry-core==1.4.0 で動作を確認。

[[tool.poetry.source]]
name = "jax"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
default = false
secondary = false

[[tool.poetry.source]]
name = "jax_cuda"
url = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
default = false
secondary = false

[tool.poetry.dependencies]
python = "^3.8.2"
jaxlib = {version = "0.4.1", source = "jax_cuda"}
jax = {version = "0.4.1", source = "jax"}

これで jaxlib==0.4.1+cuda11.cudnn86 と jax==0.4.1 が入る。cuda と cudnn のバージョンのところは環境依存かもしれない。 cuda 用の source とそうでない source を2つ書くのがポイント。 ただし、公式のインストール手順 https://github.com/google/jax#pip-installation-gpu-cuda には後者の URL は記載されておらず今後動かなくなる恐れがあるかもしれない。

cuda 用の jaxlib は Google が提供している URL から持ってくる必要があるが jax の方は PyPI にあるやつで動くはずなのだがなぜか poetry add jax だけだと

  Unable to find installation candidates for jax (0.4.1)

と言われてしまいインストールできなかった。