🐍

Rust版PyTorch(tch-rs)を試してみる

2021/03/17に公開

PyTorchの非公式Rustバインディングのtch-rsの紹介
とりあえず、手元のwindows10環境で動かしてみた。

ばんくしさんの記事で、Rustでも機械学習関連の開発が進んでいると知り、とりあえずRust版PyTorchのtch-rsをインストールしてみた。

やってみた記事としては、この記事の方が分かりやすいかも。

インストール

通常のクレートと同じ、cargo newして生成されるCargo.tomlの依存関係のところに追加するだけ、ビルドの際に自動で必要なものをダウンロードしてビルドしてくれる。

gpuを使う方法

デフォルト設定では、CPU用のlibtorchしか使ってくれない。
現状、GPUを使いたければ 手動でlibtorchをダウンロードして環境変数にパスを通す必要がある。
私の環境(win10)だと最新版の1.8.0はビルドが失敗したので、1.7.1を使用したらコンパイルが通った。

MNIST 動かす

tch-rsのサンプルから、MNISTを学習させるやつをやってみる。
サンプルコードだとCPU上で学習させているが、Device::CPUとあるところをDevice::cuda_if_available()にして、各データについてhoge.to(device)すればGPU上で動かせる。

なお、tch-rsのサンプルは実行結果を返す際に外部ライブラリの anyhowのResult型を使っているので、依存関係に追加することを忘れずに。
あと、MNISTデータセットも自動でダウンロードしてくれない ので、手動で指定されたディレクトリに展開すること。コンパイルすればどうファイルを置けばいいのかエラーで示してくれるのでそのとおりにする。

モデルを保存する関数は特別用意されてなさそう。
(別途、structを保存するやつを使う必要があるのかな?)

とりあえずMNISTを学習させるサンプル
extern crate anyhow;
extern crate tch;
use anyhow::Result;
use tch::{nn, nn::Module, nn::OptimizerConfig, Device};

const IMAGE_DIM: i64 = 784;
const HIDDEN_NODES: i64 = 128;
const LABELS: i64 = 10;

fn net(vs: &nn::Path) -> impl Module {
    nn::seq()
        .add(nn::linear(
            vs / "layer1",
            IMAGE_DIM,
            HIDDEN_NODES,
            Default::default(),
        ))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
}

pub fn run() -> Result<()> {
    let m = tch::vision::mnist::load_dir("data")?;
    let device = Device::cuda_if_available();
    let vs = nn::VarStore::new(device);
    let net = net(&vs.root());
    let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
    for epoch in 1..200 {
        let loss = net
            .forward(&m.train_images.to(device))
            .cross_entropy_for_logits(&m.train_labels.to(device));
        opt.backward_step(&loss);
        let test_accuracy = net
            .forward(&m.test_images.to(device))
            .accuracy_for_logits(&m.test_labels.to(device));
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            f64::from(&loss),
            1.   * f64::from(&test_accuracy),
        );
    }
    Ok(())
}

jitでモデルをPythonで書ける

PyTorchにはjitと言って、Pythonで定義したモデルをコンパイルして他の言語のPyTorchから利用できる仕組みがある。

サンプルコードを見る限り、tch-rsでもそれは使えそう。
しかも、ロードとセーブ機能もあるっぽい。

multi gpuは難しそう

ドキュメントを見た感じmulti gpuをするための関数が見当たらなかった。
別々のGPUに独立に学習を回すのはできるみたいだけど......

Discussion