📸

uvでflash-attentionをinstallする

2024/08/26に公開1

Pythonのバージョン管理およびパッケージ管理を行うuvflash-attentionがinstallする方法をまとめました。
この記事ではflash-attentionについての説明は行いません。

uvに関しては、バージョン0.3.0以降の基本的な操作をまとめました。参考になれば嬉しいです。

https://zenn.dev/turing_motors/articles/594fbef42a36ee

なお今回の自分の環境におけるuvおよびCUDA nvccのバージョンは以下の通りです。

uv -V
> uv 0.3.1
nvcc -V
> nvcc: NVIDIA (R) Cuda compiler driver
> Copyright (c) 2005-2023 NVIDIA Corporation
> Built on Mon_Apr__3_17:16:06_PDT_2023
> Cuda compilation tools, release 12.1, V12.1.105
> Build cuda_12.1.r12.1/compiler.32688072_0

flash-attentionのinstall

まずCUDA依存のPyTorchをinstallします。uvにおけるindex-urlの指定はこちらにまとめています。
自分の環境(cu121)では以下のようなpyproject.tomlを作成して、uv syncを実行します。

pyproject.toml
[project]
name = "uv-gpu"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "torch==2.4.0+cu121",
    "transformers[torch]>=4.44.2",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.uv]
extra-index-url = ["https://download.pytorch.org/whl/cu121/"]

flash-attentionのinstallは、以下のコマンドで可能です。

uv add hatchling editables wheel
uv add flash-attn --no-build-isolation

実行後は以下のようなpyproject.tomlが作成されます。

pyproject.toml
[project]
name = "uv-gpu"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "torch==2.4.0+cu121",
    "transformers[torch]>=4.44.2",
    "hatchling>=1.22.5",
    "editables>=0.5",
    "wheel>=0.44.0",
    "flash-attn>=2.6.3",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.uv]
extra-index-url = ["https://download.pytorch.org/whl/cu121/"]

こちらのissueが参考になりました。

https://github.com/astral-sh/uv/issues/6402

上記の問題点

こちらでflash-attentionはinstallできるものの問題があります。
自分の理解が正しければ、そもそも--no-build-isolationでのパッケージのinstallは現在のPython環境でビルドされるため、先にPyTorchがinstallされた環境でないといけません。
そのため、上記のpyproject.tomlのみを使ってuv syncするとうまくflash-attentionがinstallできない問題があります。

実際、異なる環境で上記のファイルを用いてuv syncを実行すると以下のようなエラーが出力されます。

Using Python 3.12.5
Creating virtualenv at: .venv
⠦ flash-attn==2.6.3
error: Failed to download and build `flash-attn==2.6.3`
  Caused by: Build backend failed to determine extra requires with `build_wheel()` with exit status: 1
--- stdout:

--- stderr:
Traceback (most recent call last):
  File "<string>", line 14, in <module>
  File "/home/ubuntu/.cache/uv/builds-v0/.tmpCun63X/lib/python3.12/site-packages/setuptools/build_meta.py", line 325, in get_requires_for_build_wheel
    return self._get_build_requires(config_settings, requirements=['wheel'])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.cache/uv/builds-v0/.tmpCun63X/lib/python3.12/site-packages/setuptools/build_meta.py", line 295, in _get_build_requires
    self.run_setup()
  File "/home/ubuntu/.cache/uv/builds-v0/.tmpCun63X/lib/python3.12/site-packages/setuptools/build_meta.py", line 487, in run_setup
    super().run_setup(setup_script=setup_script)
  File "/home/ubuntu/.cache/uv/builds-v0/.tmpCun63X/lib/python3.12/site-packages/setuptools/build_meta.py", line 311, in run_setup
    exec(code, locals())
  File "<string>", line 9, in <module>
ModuleNotFoundError: No module named 'packaging'
---

解決策1: --devをつけて2段階でsyncする

uvには開発依存関係としてパッケージを追加することができます。uv add--devをつけることで、tool.uv.dev-dependenciesに追加されます。
これを用いて2段階でuv syncを行うことでスムーズにflash-attentionをinstallできます。

まず、flash-attnのuv addを以下のように--devを追加します。hatchling editables wheelはそのままです。

uv add hatchling editables wheel
uv add flash-attn --no-build-isolation --dev

別環境でsyncするときは以下のコマンドで実行します。

uv sync --no-dev
uv sync --dev --no-build-isolation

これにより、先にPyTorch等の環境を構築して、flash-attention--no-build-isolationの引数のもとinstallすることが可能です。

この際のpyproject.tomlは以下のようになります。

pyproject.toml
[project]
name = "uv-gpu"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "torch==2.4.0+cu121",
    "transformers[torch]>=4.44.2",
    "hatchling>=1.22.5",
    "editables>=0.5",
    "wheel>=0.44.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.uv]
extra-index-url = ["https://download.pytorch.org/whl/cu121/"]
dev-dependencies = [
    "flash-attn>=2.6.3",
]

おわりに

uvでflash-attentionのinstallはでき、Development dependenciesを活用することでスムーズにinstallすることが可能です。他にもいい解決法があるかもしれませんし、私自身flash-attentionの使用頻度が高くないため、上記のアプローチでは問題があるかもしれません。
もし何かあればコメントいただけると助かります!

Discussion

shunk031shunk031

すばらしい記事をありがとうございました!これまで poetry を崇拝してきたのですが、flash attention を含む一部のライブラリがインストールできない問題が解決されそうにないのと、インストールが爆速になるという噂を聞いてようやく乗り換えました!uv 最高ですね

本記事では uv add --dev を使って dev-dependencies として flash attention をインストールしていますが、真に開発用の依存ライブラリである ruff、mypy、pytest といったモジュールも dev-dependencies として管理することが多いため、flash attention と一緒にこれら動作時には関係ないモジュールもインストールされてしまいます。

uv の v0.4.27 以降の uv だと Dependency groups というのが使えるみたいですので、以下のようにして開発用のライブラリと flash attention は別にインストールしても良さそうですね。

# dev 用の group を指定
uv add --group dev ruff mypy pytest

# flash-attn 用の group を指定
uv add --group flash-attn flash-attn --no-build-isolation

以下が出来上がった pyproject.toml です:

[project]
name = "flash-attn-test"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "editables>=0.5",
    "hatchling>=1.25.0",
    "setuptools>=75.3.0",
    "transformers[torch]>=4.30.0",
    "wheel>=0.44.0",
]

[dependency-groups]
dev = [
    "mypy>=1.0.0", 
    "pytest>=6.0.0", 
    "ruff>=0.1.5"
]
flash-attn = [
    "flash-attn>=2.6.3",
]

別環境で sync するときは以下のようにコマンドを実行すればよさそうです:

# dev と flash-attn のグループを抜いて sync する
uv sync --no-group dev --no-group flash-attn

# その後 dev のグループを sync する (実行環境の場合はなくても OK)
uv sync --group dev

# 最後に flash-attn のグループを sync する
uv sync --group flash-attn

追記: X 上で同様の方法が良さそうである旨教えていただきました!ありがとうございました!
https://x.com/colum2131/status/1852252153571922322