🔦

uvでflash-attentionをインストールするときのちょっとした注意?

2025/01/16に公開

TSUBAME 4.0でflash-attention使お〜と思ったときにちょっと躓いたので対処法など。

基本的にこの記事を参考にしています。ありがとうございます。
https://zenn.dev/ayutaso/articles/2a6353c657bbb7

以下の環境で確認しました。

$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Feb__7_19:32:13_PST_2023
Cuda compilation tools, release 12.1, V12.1.66
Build cuda_12.1.r12.1/compiler.32415258_0

$ uv -V
uv 0.5.18

躓きポイント

上記記事に従ってインストールを進めていたところ、主にこういうエラーで四苦八苦していました。
https://github.com/pytorch/pytorch/issues/111469
undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12の解決にLD_LIBRARY_PATHの指定が効く?と聞き試すも解決せず、結局以下の対処で落ち着きました。

  1. CUDAのバージョン
    TSUBAME 4.0はデフォルトでCUDA 12.3がロードされますが、執筆時点でのPyTorchの対応には12.3がなかったので12.1をloadしておきます。11.8でもよい?
    気づいたら12.3に戻ってた?のでちゃんと確認したほうがいいかもです。たぶん単に忘れてただけですが…

  2. インストール時のバージョン指定
    pyproject.tomlは以下のようにしました。

[project]
name = "project"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.9"
dependencies = []

[project.optional-dependencies]
build = ["torch==2.5.1+cu121", "setuptools", "packaging"]
compile = ["flash-attn"]

[tool.uv]
no-build-isolation-package = ["flash-attn"]
find-links = [
    "https://download.pytorch.org/whl/cu121/torch",
]

[[tool.uv.dependency-metadata]]
name = "flash-attn"
version = "2.6.3"
requires-dist = ["torch", "einops"]

下記スクラップを参考に、CUDA12.1対応のやつを明示しただけです。(ありがとうございます。)
https://zenn.dev/mjun0812/scraps/671db64dc42ffa

余談

当初いろいろとやり方が拙く、もしかすると解決前にごちゃごちゃやっていたキャッシュ云々が問題でありバージョン指定が必須というわけでもないかもしれませんが、念の為残しておきます。

Discussion