📝

TensorFlow/RustのEager APIを実装しました。

2022/04/17に公開

TensorFlowのRust bindingの開発をしばらくやっていました。ことの経緯はこちらをご覧ください。

https://zenn.dev/dskkato/articles/tf2-rust-binding

この度、v0.18.0が無事リリースされました。ここでは、このリリースで使えるようになったeager APIの使い方を紹介します。いろいろなことができると思うのですが、ここでは僕のもともとのモチベーションであった、「image crateを使わずに画像を読み込んでリサイズする」というユースケースを例にします。

依存ライブラリの追加

Cargo.tomlにtensorflowを追加します。featuresにeagerを追加するのを忘れないでください。

Cargo.toml
[dependencies.tensorflow]
version = "0.18.0"
features = ["eager"]

また、main.rsのはじめに使うモジュールなどを書いて置きます。

main.rs
use tensorflow as tf;
use tf::eager::{self, raw_ops, ReadonlyTensor, TensorHandle, ToTensorHandle};
use tf::Tensor;

eagerモードの実行コンテキストを作る

eager APIを使う際には、通常のTensorFlowのSession以外にeager API用の実行コンテキストを作る必要があります。このコンテキストではCPU/GPUや非同期での実行、使用するデバイスなどの設定ができるのですが、簡単のためにデフォルトのまま使うことにします。

main.rs
// eager API実行のコンテキストを作る。GPUの使用や、デバイスを指定することができる。
let opts = eager::ContextOptions::new();
let ctx = eager::Context::new(opts)?;

画像を読み込む

eager APIを使うときは、通常のTensorではなくTensorHandleを使います。これらはTensorFlowの内部ではデータの実体は同じなのですが、前者はGraph API用、後者はeager API用のようです。現在の実装では、少し楽をするためにToTensorHandleというtraitを作っています。以下では、&strから(Tensor<String>を明示的に作らずに)直接TensorHandleを作っています。また、各操作(内部的に言うと各Op)それぞれに実行コンテキストを与える必要があるため、第一引数はどれも&ctrを渡しています。

main.rs
// eager APIを使って画像を読み込み
let fname: TensorHandle = "sample/macaque.jpg".to_handle(&ctx)?;
let buf: TensorHandle = raw_ops::read_file(&ctx, &fname)?;
let img: TensorHandle = raw_ops::decode_image(&ctx, &buf)?;

同様のことをPythonでやるとこんな感じです。モジュールの違いがありますが、似たような記述ができるようになっていますね。

main.py
# 画像を読み込む
fname = "sample/macaque.jpg"
buf = tf.io.read_file(fname)
img = tf.image.decode_image(buf)

ここで一つ注意なのですが、RustのTensorHandleは型情報にデータタイプなどを持つことができておらず、不正な操作はすべて実行時にしか分かりません。

[0, 1]に正規化して、バッチ化する

こちらは先にPythonのコードを示します。画像の読み込みはuint8で0-255のスケールで返ってきます。

main.py
# floatに変換して[0,1]に正規化
img = tf.cast(img, tf.float32)
img = img / 255.0

# バッチサイズ1に変換
batch = tf.expand_dims(img, 0)

さて、この部分をRustで書くと少し変わってきます。

main.rs
// 後で[0, 1]に正規化するために、floatにキャストしておく
let cast2float = raw_ops::Cast::new().DstT(tf::DataType::Float);
let img = cast2float.call(&ctx, &img)?;
assert!(img.data_type() == tf::DataType::Float);

// [0, 1]に正規化する。255.0とすると、型の不一致でエラーになる。
let img = raw_ops::div(&ctx, &img, &255.0f32)?;

// HWC -> NHWC に変換する
let batch = raw_ops::expand_dims(&ctx, &img, &0)?;

まず、castの部分ですが、Rustではデフォルト引数を使えない関係でオプション引数が必要な操作はすべてビルダーから作成するようにしています。Opによっては見慣れないオプションがあるかもしれませんが、これは通常の整備されたPythonで使うAPIとは異なり、raw_opsの定義をそのままむき出しにしていることによるものです。また、割り算くらいの演算子はオーバーロードしてもいいかもしれませんが、そこにはまだ手を出していないです。

リサイズする

Pythonだと何のことはないです。最後のantialiasはTF 2.xから使えるようになったオプションです。

main.py
resized = tf.image.resize(batch, [224, 224], "bilinear", antialias=True)

画像を縮小する際にAntialiasのオプションを使わなかった場合の問題は以下をご覧ください。

https://twitter.com/yoya/status/1412980660554240002?s=20&t=R491uBY81NExWwQ-e37UoA

このresizeをナイーブに実装すると、ResizeBilinearを使えば良さそうですが、これにはantialiasのオプションがありません。実は、TF 2.xだと、ScaleAndTranslateという別のOpで実装されています。さて、これを使ってリサイズしてみます。

main.rs
// [224, 224, 3]にリサイズする。
// ここではantialiasを有効にするために、v2のAPIを使う。
let resize_bilinear = raw_ops::ScaleAndTranslate::new()
    .kernel_type("triangle") // bilinearのオプションに相当
    .antialias(true);
let scale = [224.0 / height as f32, 224.0 / width as f32];
let resized = resize_bilinear.call(&ctx, &batch, &[224, 224], &scale, &[0f32, 0f32])?;

結果を確認する

Pythonの結果を正として、Rustの結果と一致するかを確認する。

main.py
# 1ピクセル目の値を確認する。
print(f"{resized[0, 0, 0, :3]}")
# [0.29298395 0.35878524 0.4291904 ]

RustのほうでTensorHandleの内部のバッファにアクセスするには、Tensorに戻す必要があります。ReadonlyTensorを経由しているのは、Rust側の操作でpointer aliasingによるUBの発生を避けるためにやむを得ず導入したものです。

main.rs
// Tensorの中身にアクセスできるように、TensorHandleからTensorに戻す
// 今の実装では、ReadonlyTensorを経由してTensorに戻す必要がある。
let t: ReadonlyTensor<f32> = resized.resolve()?;
// let t: Tensor<f32> = unsafe { t.into_tensor() };

// resize後の1つ目のピクセルについて、
// Pythonで計算した結果と比較する
assert!((t[0] - 0.29298395).abs() < 1e-5);
assert!((t[1] - 0.35878524).abs() < 1e-5);
assert!((t[2] - 0.42919040).abs() < 1e-5);

1ピクセル分しか確認してませんが、私の手元ではきちんと結果が一致していそうでした。

参考

以上の比較コードはこちらに置きました。

https://github.com/dskkato/tf-rust-eager-sample

また、もともとtch -> tfに置き換えていたサンプルコードも今回のアップデートに合わせて更新しました。

https://github.com/dskkato/rust-machine-learning-api-example

やっとマージされたばかりなので、不具合等があるかもしれません。こちらのIssuesに上げていただけると私が対応すると思います。

https://github.com/tensorflow/rust

Discussion