Rust版PyTorchでDCGANチュートリアル
PyTorchの非公式Rustバインディングのtch-rsでDCGANのチュートリアルを動かしてみた。
導入方法は過去記事参照(普通にCargo.tomlに書けばいい)
win10のローカル環境で動かしました。時間があれば、colabでも動かして記事を書くかも。
やったこと
書いたコード
以下を参考に、ただのDCGANを学習させてみた。
設定を外出しする
普段Pythonで学習を回す時など、Configクラスを作って、jsonから設定を読み込んでいるので、Rustでもやってみる。
serde_jsonを使えば、jsonファイルをよしなに変換して、構造体を生成してくれる。
serde_jsonの使い方はここを参考
パスを受け取る<hoge>root
の型がBox<String>
になっているのは、Trainer構造体にconfigをまるまる持たせるときに、ただのString
所有権周りでごたついたから。(もう少し上手くやりたい)
Configクラス
#[derive(Serialize, Deserialize, Debug)]
struct TrainerConfig {
dataroot: Box<String>,
modelroot: Box<String>,
saveroot: Box<String>,
latent_dim: i64,
img_size: i64,
beta1: f64,
lr: f64,
batch_size:i64,
use_gpu:bool,
max_steps:u32,
logging_steps:u32,
eval_steps:u32,
save_steps:u32,
}
impl TrainerConfig {
fn new<T: AsRef<Path>>(config_path:T) -> TrainerConfig {
let file = File::open(config_path).unwrap();
let reader = BufReader::new(file);
serde_json::from_reader(reader).unwrap()
}
}
jitでモデル定義
モデル定義はPyTorchで書いて、jitを使ってモデルデータを書き出してRustに読み込ませた。
Rustでもモデル定義は出来るが、その場合のモデルの保存方法がtch-rsのサンプルになかった。
jitなら、読み込みも書き込みもtch-rsの方で関数が定義されている。
jitの書き出し方法は、traceとscriptがある。
traceは実際にデータを流して、その時の計算グラフを元に出力する。
scriptはscript用のクラスを継承してコードを元に出力する。
今回はtraceを使った。
あと、tch-rsは jitで読み込むのはただのModuleじゃなくてCModule構造体
学習を回すならTrainableCModuleを使うこと
学習を回す
データ周り
データセットはご存知セレブのやつ
データの読み込みは、tch-rsの方ですでに、tch::vision::image::load_dir
関数が用意されていて、指定したディレクトリにある画像を全部読んで、1つのTensorにして返してくれる。
他にも、mnistやcifar-10用の読み込み関数があるっぽい。
画像の書き出し関数もあった
パラメータの更新
あとは普通にサンプルコードをコピペして学習を回すだけだが、ganだと、生成モデル・判別モデルの片方のパラメータだけ更新したい。
tch-rsでは、VarStore
にモデルやoptimizerを載せるが、複数モデルを学習させる場合、それぞれ別のVarStore
に載せる必要がある。
VarStore
単位で、freeze()
, unfreeze()
を呼びだしてパラメータを更新するかどうか制御するっぽい。
ただの感想
forwardの部分がRustらしくメソッドチェーンしていた(Pythonと違って)。
割とPythonのときの感覚で書いたので、あんまりRustっぽくは無い気がする。
Rustだとエラーメッセージを書くのが簡単。
別にRustだからといって、コンパイル時にTensorの大きさの不一致を教えてくれない。
使えるけど、Pythonでゴリゴリ書けてる人は別にRustに移動する必要ない気がする。
(個人的にRust楽しいと思うので、今後ライブラリが充実してくればRust使いたい)
Discussion