🔦
uvでflash-attentionをインストールするときのちょっとした注意?
TSUBAME 4.0でflash-attention使お〜と思ったときにちょっと躓いたので対処法など。
基本的にこの記事を参考にしています。ありがとうございます。
以下の環境で確認しました。
$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Feb__7_19:32:13_PST_2023
Cuda compilation tools, release 12.1, V12.1.66
Build cuda_12.1.r12.1/compiler.32415258_0
$ uv -V
uv 0.5.18
躓きポイント
上記記事に従ってインストールを進めていたところ、主にこういうエラーで四苦八苦していました。undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12
の解決にLD_LIBRARY_PATH
の指定が効く?と聞き試すも解決せず、結局以下の対処で落ち着きました。
-
CUDAのバージョン
TSUBAME 4.0はデフォルトでCUDA 12.3がロードされますが、執筆時点でのPyTorchの対応には12.3がなかったので12.1をloadしておきます。11.8でもよい?
気づいたら12.3に戻ってた?のでちゃんと確認したほうがいいかもです。たぶん単に忘れてただけですが… -
インストール時のバージョン指定
pyproject.tomlは以下のようにしました。
[project]
name = "project"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.9"
dependencies = []
[project.optional-dependencies]
build = ["torch==2.5.1+cu121", "setuptools", "packaging"]
compile = ["flash-attn"]
[tool.uv]
no-build-isolation-package = ["flash-attn"]
find-links = [
"https://download.pytorch.org/whl/cu121/torch",
]
[[tool.uv.dependency-metadata]]
name = "flash-attn"
version = "2.6.3"
requires-dist = ["torch", "einops"]
下記スクラップを参考に、CUDA12.1対応のやつを明示しただけです。(ありがとうございます。)
余談
当初いろいろとやり方が拙く、もしかすると解決前にごちゃごちゃやっていたキャッシュ云々が問題でありバージョン指定が必須というわけでもないかもしれませんが、念の為残しておきます。
Discussion