🐍

UV + Pytorch Nightly(CUDA12.8)を使う

に公開

Pytorchのバージョン

詳細は以下の記事を確認してください。
https://zenn.dev/cynagenautes/articles/0566bd145c350c
torch-2.8.0.dev20250524以降のバージョンを使用することで問題の回避が可能です。

UVでのパッケージ管理

uv add torchでPytorchを追加する際はPyPlが参照され、このときのCUDAのバージョンは時期によって変動してしまうので管理がしづらいです。これはUVのドキュメントにも記述されており、pyproject.tomlへ追加で記述を行うことで、Pytorchが別途提供するCUDAバージョンごとのindex-urlから取得することが可能になります。
https://docs.astral.sh/uv/guides/integration/pytorch/

下記はCUDA12.8.1でRTX 50シリーズGPUでの速度問題が改善されたビルドを使用する場合のコンフィグです。ファイルに項目を追加した上でuv syncを実行することでインストール可能です。

pytorch-tritonを追加するのが重要で、この依存関係はPytorchのindexからしか解決することができません。バージョンにGitのコミットハッシュが含まれており、PyPlで解決ができないことが理由だと考えられます。

dependencies = [
    "pytorch-triton==3.3.1+gitc8757738",
    "torch==2.8.0.dev20250524+cu128",
    "torchvision==0.22.0.dev20250524+cu128",
]

[tool.uv.sources]
torch = [{ index = "pytorch-cu128" }]
torchvision = [{ index = "pytorch-cu128" }]
pytorch-triton = [{ index = "pytorch-cu128" }]

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/nightly/cu128"
explicit = true

Discussion