👏

uvでflash-attentionをinstallする2 【uvバージョン0.5.18】

に公開

Pythonのバージョン管理およびパッケージ管理を行うuvflash-attentionをinstallする方法をまとめました。バージョン0.3時点の内容が以下の記事で紹介されていますが、2025年1月13日時点では公式ドキュメントにuvでflash-attentionをinstallする方法が紹介されていたので紹介します。以下の内容は公式ドキュメントの抜粋のような内容です。
https://zenn.dev/colum2131/articles/342b7bdb20c54e

なお、uvの更新は非常に活発なため、この記事の内容から公式ドキュメントが既に更新されている可能性も十分ありますので、公式ドキュメントを読みに行くことを推奨しておきます。以下が公式ドキュメントの該当箇所です。

https://docs.astral.sh/uv/concepts/projects/config/#build-isolation

今回の環境におけるuvおよびCUDAのバージョンは以下の通りです。

$ uv -V
uv 0.5.18
$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0

flash-attentionのinstall

flash-attnのようなパッケージは、依存関係の解決フェーズ(lockfile作成時)でもビルド依存関係を必要とします。
そこで、uvバージョン0.5.18ではflash-attnに対して、依存関係のメタデータを事前に提供することで、依存関係解決フェーズ中にパッケージをビルドする必要をなくすことができます。

依存関係のメタデータの記述

メタデータは、pyproject.toml に以下のように記述します

[project]
name = "project"
version = "0.1.0"
description = "..."
readme = "README.md"
requires-python = ">=3.11"
dependencies = []

[project.optional-dependencies]
build = ["torch", "setuptools", "packaging"]
compile = ["flash-attn"]

[tool.uv]
no-build-isolation-package = ["flash-attn"]

[[tool.uv.dependency-metadata]]
name = "flash-attn"
version = "2.6.3"
requires-dist = ["torch", "einops"]

インストール

メタデータを提供しておくと、インストールは以下のコマンドで実行できます。

$ uv sync --extra build
$ uv sync --extra build --extra compile

実行結果

$ uv sync --extra build
Resolved 27 packages in 0.77ms
Installed 24 packages in 258ms
 + filelock==3.16.1
 + fsspec==2024.12.0
 + jinja2==3.1.5
 + markupsafe==3.0.2
 + mpmath==1.3.0
 + networkx==3.4.2
 + nvidia-cublas-cu12==12.4.5.8
 + nvidia-cuda-cupti-cu12==12.4.127
 + nvidia-cuda-nvrtc-cu12==12.4.127
 + nvidia-cuda-runtime-cu12==12.4.127
 + nvidia-cudnn-cu12==9.1.0.70
 + nvidia-cufft-cu12==11.2.1.3
 + nvidia-curand-cu12==10.3.5.147
 + nvidia-cusolver-cu12==11.6.1.9
 + nvidia-cusparse-cu12==12.3.1.170
 + nvidia-nccl-cu12==2.21.5
 + nvidia-nvjitlink-cu12==12.4.127
 + nvidia-nvtx-cu12==12.4.127
 + packaging==24.2
 + setuptools==75.8.0
 + sympy==1.13.1
 + torch==2.5.1
 + triton==3.1.0
 + typing-extensions==4.12.2
$ uv sync --extra build --extra compile
Resolved 27 packages in 0.65ms
Installed 2 packages in 14ms
 + einops==0.8.0
 + flash-attn==2.7.3

おわりに

日本語でググって古い情報にあたって困ることが多いので、特に進化の速いソフトウェアについては素直に公式ドキュメント読むのが良いなと思いました。とはいえ、ググって公式ドキュメントの該当箇所が張ってある記事があるだけでも嬉しいかなと思って書いた記事でした。

https://docs.astral.sh/uv/

Discussion