🕯️

Pytorchのlibtorchをrustで使うtch-rsに関するメモ

2023/08/31に公開

はじめに

Rust言語がマイブームになっているのでDeep Learningの学習/推論もやりたくなりました。
Deep Learning系のRust用crateはいくつかあるようです。

burnはなかなか面白そうで,backendでtchやwgpuやndarrayを使うことを選択できるみたいです。burnは少し試してみましたがTensorの操作などに一癖ある感じで難しかったのであきらめました。

そこで,pytorchを使ったことがあるのでtchを使うことにしました。
tchはpytorchのc++ライブラリをバインドしたクレートです。
tch使ってみて気が付いたことを共有します。

buildできるようにするための設定 (wsl2)

ビルド済みのライブラリを使うために,pytorchをインストールして利用します。
tch 0.13ではpytorch 2.0.0が前提のようなので,2.0.0をpipでinstallしました。

libtorchのビルド済みライブラリが含まれるフォルダをを以下の2つのパスで指定するか,

pytorchのlibを使う場合。

export LIBTORCH_USE_PYTORCH=1
export LD_LIBRARY_PATH=LD_LIBRARY_PATH=$HOME/.pyenv/versions/3.9.6/lib/python3.9/site-packages/torch/lib/:$LD_LIBRARY_PATH

または自分でビルドしたものを使う場合

export LIBTORCH=$HOME/.pyenv/versions/3.9.6/lib/python3.9/site-packages/torch
export LD_LIBRARY_PATH=LD_LIBRARY_PATH=$HOME/.pyenv/versions/3.9.6/lib/python3.9/site-packages/torch/lib/:$LD_LIBRARY_PATH

macの場合はLD_LIBRARY_PATHの代わりにDYLD_LIBRARY_PATHにtorchのlibフォルダ設定する

ビルド時にlibtorchをDLする場合

featuresに以下を設定すればDLしてtarget以下のフォルダにlibtorchが格納されていました。
こちらの場合はPATHを設定しなくても大丈夫そう。

tch = {version="0.13.0"}
torch-sys = {version="0.13.0", features=["download-libtorch"]}

exampleのコード

example実装はtch-rsのリポジトリに大量にあるので参考にできる。
https://github.com/LaurentMazare/tch-rs/tree/main/examples

stable-diffusionをtch-rsで実装したリポジトリもあり,すぐに試すことができました。
https://github.com/LaurentMazare/diffusers-rs

実装に関するメモ

モデルのパラメータ管理

モデルの重みはVarStoreで管理されている。

let vs = nn::VarStore::new(Device::cuda_if_available());

モデルを構成するパラメータは Path => TensorのMapで管理されており,Layerを初期化するときの引数にPathを指定する。
以下のlinearの場合,重みWbがそれぞれVarStoreに登録され,layer1.weight, layer1.biasのように.繋ぎで下の階層に重みが登録される。

let vs = &vs.root();
nn::linear(
            vs / "layer1",
            IMAGE_DIM,
            HIDDEN_NODES,
            Default::default(),
)

パラメータのfreeze

ネットワークの一部のパラメータをfreezeして,学習ステップ時に学習させたくない場合はよくあると思う。
そのような場合,VarStoreを分ける方法と,Path(Key)で該当するTensorに対してset_requires_grad(false)する方法があると思います。

前者の場合,VarStoreに対してfreezeを行えばOKです。

vs.freeze();

後者の場合,少々面倒ですが以下のようにすることで,一部のfreezeができそうです。
Pathは.繋ぎなので,部分一致したKeyのTensorだけset_requires_grad(false)を処理します。

pub fn freeze_encoder(vs: &mut nn::VarStore) {
let varis = vs.variables();
let encoder_keys = [ENCODER];
println!("freeze encoder");
let mut keys2: Vec<&str> = vec![];
// collect sub keys
for k in &encoder_keys {
    let sub_keys = varis
	.keys()
	.into_iter()
	.filter(|&x| x.starts_with(k))
	.collect::<Vec<&String>>();
    if sub_keys.len() == 0 {
	panic!("{} is not found", k);
    }
    for sub_key in sub_keys {
	keys2.push(sub_key);
    }
}
println!("encoder param keys: {:?}", keys2);
for key in keys2 {
    let v = varis.get(key);
    if let Some(v) = v {
	let _v = v.set_requires_grad(false);
    }
}
}

v1系のpytorch,libtorchを使う場合

現在のv0.13はpytorch v2.0.0だけを許容している。
cuda versionの関係などでv1系のpytorchを使いたい場合は過去バージョンを使うしかないようだ。

https://crates.io/crates/tch

おわり

研究で使っていたDeepLearningのコードをtchで書き直してみたのですが,なかなか書きやすいです。Rustなのでexampleやtestを使ったデバッグもしやすいと感じました。
運用方法としては,モデル定義したクレートを作成して,examplesにtrain用のコードを書いてパラメータを変えて学習を試すというようなことをしています。

Discussion