🤗

Rust で transformer を動かしてみて躓いたこと6選

2025/01/16に公開

機械学習のフレームワークといえば Python の Hugging Face 🤗 が有名ですが,それを Rust でも使えるようにしたい!ということで,🤗 による公式の candle というプロジェクトがあります.

Rust 好きとしてはぜひとも触ってみたいと思い試してみたのですが,思いの外躓いた点が多かったので,知見として共有しようと思います.

cuda.h が見つからない!

Anaconda で諸々インストールしている場合,CUDA の C(++) ヘッダは $CONDA_PREFIX/include 以下にあります.
もちろん conda env で仮想環境を立てているのであれば,その env の /include 以下となります.

私はそれに気付かず,cuda.h が見つからないというエラーを貰いました.
これに対処するには,

  1. conda activate MY_ENV
  2. CUDA_ROOT=$CONDA_PREFIX

によって正しくヘッダファイルの場所を指定する必要があります.

libcublas.so が見つからない!

ヘッダファイルが見つからなければ,当然 shared object も見つかりません.LD_LIBRARY_PATH を指定してあげましょう.

LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH

libtinfo.so.6: no version information available

ただし,上二つの環境変数を export してエディタを開いてはいけません!!!

※私の環境(NeoVim + rust.vim + rust-analyzer)の場合は,です.他の環境は検証していないため,分かりません.

対処方法としては,

  • 環境変数を export せず,毎回 CUDA_ROOT=... LD_LIBRARY_PATH=... cargo run と実行する
  • 環境変数を export し,別プロセスの子としてエディタを開く

のいずれかが必要です.

私自身原因は把握していませんが,.vimrc

let g:python3_host_prog  = '/.../anaconda3/bin/python3'

と設定しているのが CUDA_ROOT=$CONDA_PREFIX と衝突するのか,ファイル保存時に

/bin/zsh: /.../anaconda3/envs/research/lib/libtinfo.so.6: no version information available (required by /bin/zsh)
bash: /.../anaconda3/envs/research/lib/libtinfo.so.6: no version information available (required by bash)
bash: /.../anaconda3/envs/research/lib/libtinfo.so.6: no version information available (required by bash)
bash: /.../anaconda3/envs/research/lib/libtinfo.so.6: no version information available (required by bash)

というエラーメッセージが Rust ファイルの先頭に(!!)挿入されてしまいます.

ファイル保存時にということは rust.vim

hook_add = '''
let g:rustfmt_autosave = 1
'''

という設定が怪しいのですが,詳しい原因は謎のままです.

Hub からのダウンロード時に DNS エラーが出る!

candle には,hf-hub を通じて Hugging Face Hub からモデルをダウンロードする,という機能があります.

さっそくサンプルの通りに google/t5-small を試そうと思ったら……

Error: RequestError(reqwest::Error { kind: Request, url: "https://huggingface.co/t5-small/resolve/main/config.json", source: hyper_util::client::legacy::Error(Connect, ConnectError("dns error", Custom { kind: Uncategorized, error: "failed to lookup address information: Temporary failure in name resolution" })) })

DNS の名前解決に失敗したというエラーが.
これは hf-hub の内部でデフォルトで用いられている native-tls私の環境だと動作しなかったためです.
代わりに rustls を用いるように修正すると,動作するようになります.

hf-hub = { version = "0.4", default-features = false, features = ["tokio", "rustls-tls"] }

同じバージョンの candle_core が衝突する!

さて,ここまでセットアップしてようやく t5-small を試せるぞ!と意気込んで,candle-examples を試そうとしたところ,

error[E0308]: arguments to this function are incorrect
   --> src/main.rs:32:27
    |
32  |         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weights, DTYPE, device)? };
    |                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^           -----  ------ expected `candle_transformers::models::mimi::candle_core::Device`, found `candle_core::Device`
    |                                                                         |
    |                                                                         expected `candle_transformers::models::mimi::candle_core::DType`, found `candle_core::DType`
    |
    = note: `candle_core::DType` and `candle_transformers::models::mimi::candle_core::DType` have similar names, but are actually distinct types
note: `candle_core::DType` is defined in crate `candle_core`
   --> /.../cargo/git/checkouts/candle-0c2b4fa9e5801351/17cbbe4/candle-core/src/dtype.rs:8:1
    |
8   | pub enum DType {error[E0308]: arguments to this function are incorrect
   --> src/main.rs:32:27
    |
32  |         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weights, DTYPE, device)? };
    |                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^           -----  ------ expected `candle_transformers::models::mimi::candle_core::Device`, found `candle_core::Device`
    |                                                                         |
    |                                                                         expected `candle_transformers::models::mimi::candle_core::DType`, found `candle_core::DType`
    |
    = note: `candle_core::DType` and `candle_transformers::models::mimi::candle_core::DType` have similar names, but are actually distinct types
note: `candle_core::DType` is defined in crate `candle_core`
   --> /.../cargo/git/checkouts/candle-0c2b4fa9e5801351/17cbbe4/candle-core/src/dtype.rs:8:1
    |
8   | pub enum DType {
    | ^^^^^^^^^^^^^^
note: `candle_transformers::models::mimi::candle_core::DType` is defined in crate `candle_core`
   --> /.../cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-core-0.8.2/src/dtype.rs:8:1
    |
8   | pub enum DType {
    | ^^^^^^^^^^^^^^
    = note: perhaps two different versions of crate `candle_core` are being used?
    = note: `candle_core::Device` and `candle_transformers::models::mimi::candle_core::Device` have similar names, but are actually distinct types
note: `candle_core::Device` is defined in crate `candle_core`
   --> /.../cargo/git/checkouts/candle-0c2b4fa9e5801351/17cbbe4/candle-core/src/device.rs:16:1
    |
16  | pub enum Device {
    | ^^^^^^^^^^^^^^^
note: `candle_transformers::models::mimi::candle_core::Device` is defined in crate `candle_core`
   --> /.../cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-core-0.8.2/src/device.rs:16:1
    |
16  | pub enum Device {
    | ^^^^^^^^^^^^^^^
    = note: perhaps two different versions of crate `candle_core` are being used?
note: associated function defined here
   --> /.../cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-nn-0.8.2/src/var_builder.rs:515:19
    |
515 |     pub unsafe fn from_mmaped_safetensors<P: AsRef<std::path::Path>>(
    |                   ^^^^^^^^^^^^^^^^^^^^^^^

    | ^^^^^^^^^^^^^^
note: `candle_transformers::models::mimi::candle_core::DType` is defined in crate `candle_core`
   --> /.../cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-core-0.8.2/src/dtype.rs:8:1
    |
8   | pub enum DType {
    | ^^^^^^^^^^^^^^
    = note: perhaps two different versions of crate `candle_core` are being used?
    = note: `candle_core::Device` and `candle_transformers::models::mimi::candle_core::Device` have similar names, but are actually distinct types
note: `candle_core::Device` is defined in crate `candle_core`
   --> /.../cargo/git/checkouts/candle-0c2b4fa9e5801351/17cbbe4/candle-core/src/device.rs:16:1
    |
16  | pub enum Device {
    | ^^^^^^^^^^^^^^^
note: `candle_transformers::models::mimi::candle_core::Device` is defined in crate `candle_core`
   --> /.../cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-core-0.8.2/src/device.rs:16:1
    |
16  | pub enum Device {
    | ^^^^^^^^^^^^^^^
    = note: perhaps two different versions of crate `candle_core` are being used?
note: associated function defined here
   --> /.../cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-nn-0.8.2/src/var_builder.rs:515:19
    |
515 |     pub unsafe fn from_mmaped_safetensors<P: AsRef<std::path::Path>>(
    |                   ^^^^^^^^^^^^^^^^^^^^^^^

このようなエラーが!

内容としては,関数の引数として candle_transformers::...::{Device,DType} が要求されているのに,candle_core::{Device,DType} を渡してしまっているよ!というものです.

ですが,ここにはおかしな点が二つあります.

  1. VarBuilder が要求しているのは間違いなく candle_core::DType のはずである.
  2. candle_transformers:0.8.2 が依存している candle_core と,Cargo.toml に書いた candle_core (v0.8.2) は同じバージョンのはずなのに,perhaps two different versions of crate `candle_core` are being used? と言われている.決して different versions ではありません,と声を大にして言いたい.

発生したエラーを嘆いていても仕方ないので,解決策を考えます.
とはいうものの,これは Rust コンパイラの問題であり,根本的な対策は存在しないように思えます.

なので私は,

  1. candleローカルに git clone してくる.
  2. Cargo.toml に,path = "..." として指定する.

という方法を採りました.

candle-core = { path = "./lib/candle/candle-core", features = ["cuda"] }
candle-nn = { path = "./lib/candle/candle-nn" }
candle-transformers = { path = "./lib/candle/candle-transformers" }

前時代のパッケージ管理に戻った気分です.

softmax-last-dim が実装されていない!

上記のような力技で解決した後,今度こそ T5 のサンプルを,と実行すると,次のそっけないエラーが発生しました.

Error: no cuda implementation for softmax-last-dim

これは,上記の candle-nn のデフォルトの features において,CUDA を用いようとすると fallback の実装が呼び出されてしまったのが原因です.
なので,しっかり candle-nn にも cuda feature を指定しなければなりません.

candle-nn = { path = "./lib/candle/candle-nn", features = ["cuda"] }

動いた!!!!

以上の頑張りの結果,無事 T5 を動かすことができました.

$ CUDA_ROOT=$CONDA_PREFIX LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH cargo r --release
    Finished `release` profile [optimized] target(s) in 0.08s
     Running `target/release/candletest`
 Eine schöne Kerze.
6 tokens generated (364.30 token/s)

感慨も一入といったところで,もっと色々試してみようと思います.

動くサンプルはこちら

https://github.com/naughie/test-candle-rs

Discussion