🐍

Rust版PyTorchでDCGANチュートリアル

3 min read

PyTorchの非公式Rustバインディングのtch-rsでDCGANのチュートリアルを動かしてみた。
導入方法は過去記事参照(普通にCargo.tomlに書けばいい)
win10のローカル環境で動かしました。時間があれば、colabでも動かして記事を書くかも。

やったこと

書いたコード

以下を参考に、ただのDCGANを学習させてみた。

設定を外出しする

普段Pythonで学習を回す時など、Configクラスを作って、jsonから設定を読み込んでいるので、Rustでもやってみる。

serde_jsonを使えば、jsonファイルをよしなに変換して、構造体を生成してくれる。
serde_jsonの使い方はここを参考

パスを受け取る<hoge>rootの型がBox<String>になっているのは、Trainer構造体にconfigをまるまる持たせるときに、ただのString所有権周りでごたついたから。(もう少し上手くやりたい)

Configクラス
#[derive(Serialize, Deserialize, Debug)]
struct TrainerConfig {
    dataroot: Box<String>,
    modelroot: Box<String>,
    saveroot: Box<String>,
    latent_dim: i64,
    img_size: i64,
    beta1: f64,
    lr: f64,
    batch_size:i64,
    use_gpu:bool,
    max_steps:u32,
    logging_steps:u32,
    eval_steps:u32,
    save_steps:u32,
}

impl TrainerConfig {
    fn new<T: AsRef<Path>>(config_path:T) -> TrainerConfig {
        let file = File::open(config_path).unwrap();
        let reader = BufReader::new(file);

        serde_json::from_reader(reader).unwrap()
    }
}

jitでモデル定義

モデル定義はPyTorchで書いて、jitを使ってモデルデータを書き出してRustに読み込ませた。
Rustでもモデル定義は出来るが、その場合のモデルの保存方法がtch-rsのサンプルになかった。
jitなら、読み込みも書き込みもtch-rsの方で関数が定義されている。

jitについてはこの記事を参考

jitの書き出し方法は、traceとscriptがある。
traceは実際にデータを流して、その時の計算グラフを元に出力する。
scriptはscript用のクラスを継承してコードを元に出力する。
今回はtraceを使った。

あと、tch-rsは jitで読み込むのはただのModuleじゃなくてCModule構造体
学習を回すならTrainableCModuleを使うこと

学習を回す

データ周り

データセットはご存知セレブのやつ
データの読み込みは、tch-rsの方ですでに、tch::vision::image::load_dir関数が用意されていて、指定したディレクトリにある画像を全部読んで、1つのTensorにして返してくれる。
他にも、mnistやcifar-10用の読み込み関数があるっぽい。
画像の書き出し関数もあった

パラメータの更新

あとは普通にサンプルコードをコピペして学習を回すだけだが、ganだと、生成モデル・判別モデルの片方のパラメータだけ更新したい。
tch-rsでは、VarStoreにモデルやoptimizerを載せるが、複数モデルを学習させる場合、それぞれ別のVarStoreに載せる必要がある。
VarStore単位で、freeze(), unfreeze()を呼びだしてパラメータを更新するかどうか制御するっぽい。

ただの感想

forwardの部分がRustらしくメソッドチェーンしていた(Pythonと違って)。
割とPythonのときの感覚で書いたので、あんまりRustっぽくは無い気がする。
Rustだとエラーメッセージを書くのが簡単。
別にRustだからといって、コンパイル時にTensorの大きさの不一致を教えてくれない。
使えるけど、Pythonでゴリゴリ書けてる人は別にRustに移動する必要ない気がする。
(個人的にRust楽しいと思うので、今後ライブラリが充実してくればRust使いたい)

Discussion

ログインするとコメントできます