Pytorchのlibtorchをrustで使うtch-rsに関するメモ
はじめに
Rust言語がマイブームになっているのでDeep Learningの学習/推論もやりたくなりました。
Deep Learning系のRust用crateはいくつかあるようです。
- tch: https://github.com/LaurentMazare/tch-rs
- burn: https://github.com/burn-rs/burn/tree/main
- tensorflow: https://github.com/tensorflow/rust
burnはなかなか面白そうで,backendでtchやwgpuやndarrayを使うことを選択できるみたいです。burnは少し試してみましたがTensorの操作などに一癖ある感じで難しかったのであきらめました。
そこで,pytorchを使ったことがあるのでtchを使うことにしました。
tchはpytorchのc++ライブラリをバインドしたクレートです。
tch使ってみて気が付いたことを共有します。
buildできるようにするための設定 (wsl2)
ビルド済みのライブラリを使うために,pytorchをインストールして利用します。
tch 0.13ではpytorch 2.0.0が前提のようなので,2.0.0をpipでinstallしました。
libtorchのビルド済みライブラリが含まれるフォルダをを以下の2つのパスで指定するか,
pytorchのlibを使う場合。
export LIBTORCH_USE_PYTORCH=1
export LD_LIBRARY_PATH=LD_LIBRARY_PATH=$HOME/.pyenv/versions/3.9.6/lib/python3.9/site-packages/torch/lib/:$LD_LIBRARY_PATH
または自分でビルドしたものを使う場合
export LIBTORCH=$HOME/.pyenv/versions/3.9.6/lib/python3.9/site-packages/torch
export LD_LIBRARY_PATH=LD_LIBRARY_PATH=$HOME/.pyenv/versions/3.9.6/lib/python3.9/site-packages/torch/lib/:$LD_LIBRARY_PATH
macの場合はLD_LIBRARY_PATHの代わりにDYLD_LIBRARY_PATHにtorchのlibフォルダ設定する
ビルド時にlibtorchをDLする場合
featuresに以下を設定すればDLしてtarget以下のフォルダにlibtorchが格納されていました。
こちらの場合はPATHを設定しなくても大丈夫そう。
tch = {version="0.13.0"}
torch-sys = {version="0.13.0", features=["download-libtorch"]}
exampleのコード
example実装はtch-rsのリポジトリに大量にあるので参考にできる。
stable-diffusionをtch-rsで実装したリポジトリもあり,すぐに試すことができました。
実装に関するメモ
モデルのパラメータ管理
モデルの重みはVarStoreで管理されている。
let vs = nn::VarStore::new(Device::cuda_if_available());
モデルを構成するパラメータは Path => TensorのMapで管理されており,Layerを初期化するときの引数にPathを指定する。
以下のlinearの場合,重みlayer1.weight
, layer1.bias
のように.
繋ぎで下の階層に重みが登録される。
let vs = &vs.root();
nn::linear(
vs / "layer1",
IMAGE_DIM,
HIDDEN_NODES,
Default::default(),
)
パラメータのfreeze
ネットワークの一部のパラメータをfreezeして,学習ステップ時に学習させたくない場合はよくあると思う。
そのような場合,VarStoreを分ける方法と,Path(Key)で該当するTensorに対してset_requires_grad(false)
する方法があると思います。
前者の場合,VarStore
に対してfreeze
を行えばOKです。
vs.freeze();
後者の場合,少々面倒ですが以下のようにすることで,一部のfreezeができそうです。
Pathは.
繋ぎなので,部分一致したKeyのTensorだけset_requires_grad(false)
を処理します。
pub fn freeze_encoder(vs: &mut nn::VarStore) {
let varis = vs.variables();
let encoder_keys = [ENCODER];
println!("freeze encoder");
let mut keys2: Vec<&str> = vec![];
// collect sub keys
for k in &encoder_keys {
let sub_keys = varis
.keys()
.into_iter()
.filter(|&x| x.starts_with(k))
.collect::<Vec<&String>>();
if sub_keys.len() == 0 {
panic!("{} is not found", k);
}
for sub_key in sub_keys {
keys2.push(sub_key);
}
}
println!("encoder param keys: {:?}", keys2);
for key in keys2 {
let v = varis.get(key);
if let Some(v) = v {
let _v = v.set_requires_grad(false);
}
}
}
v1系のpytorch,libtorchを使う場合
現在のv0.13はpytorch v2.0.0だけを許容している。
cuda versionの関係などでv1系のpytorchを使いたい場合は過去バージョンを使うしかないようだ。
おわり
研究で使っていたDeepLearningのコードをtchで書き直してみたのですが,なかなか書きやすいです。Rustなのでexampleやtestを使ったデバッグもしやすいと感じました。
運用方法としては,モデル定義したクレートを作成して,examplesにtrain用のコードを書いてパラメータを変えて学習を試すというようなことをしています。
Discussion