TensorFlow 2.xでのRustとPython
TensorFlow 2.xでのRustとPythonの相互運用について説明します。KerasのモデルをRustで利用する方法についても説明します。
ここではGraphを直接エクスポートする方法を説明していすが、こちらではSavedModelのフォーマットの利用例を説明しています。
前置き
ちょうど最近TensorFlow 2を使っており、Twitterでこれを見かけました。
さっそく動かしてみたところ、確かに動いたのですがPython側のモデル定義のスクリプトがTensorFlow 1.xのままでした。特にKerasのモデルを利用する例が載っていませんでした。
TensorFlow 2.xで何が変わった?
TensorFlowは1.xではdefine-and-runと呼ばれていたように、計算グラフを定義し、そのグラフをセッションで実行する方式を採用していました。この方式はTensorFlow 2.xでも変わっていませんが、PythonのAPIはこれらの低水準な実装詳細を隠すようにすると同時に、PyTorchなどで見られたdefine-by-runと呼ばれるスタイルに様変わりしました。
この変更のおかげで1.xに比べるとPython側からはかなり使いやすくなりましたが、学習したモデルを他の言語から利用する場合に戸惑うことが多くなったように思います。これは、ひとえに他言語からの利用がC-APIを通して行われるためであり、このC-APIはTensorFlow 1.xと同様に明示的にグラフやセッションを管理しなければならないままだからだと思われます。
この記事では、Pythonの高レベルAPIで分かりにくくなったTensorFlow 2.xの相互運用について説明します。C-APIを利用する例はRustを例にしますが、他の言語でもここで説明する情報が役に立つはずです。
なお、RustのTensorFlowバインディングはつい先日リリースされたv0.17.0にて、バックエンドが1.15から2.5にアップデートされました。これでPythonとのバージョン不一致による互換性の問題は起こりにくい状況になったかと思います。
Kerasのモデルをグラフ形式で保存する。@Python側
自前で学習したモデルでももちろんいいのですが、ここではMobileNetV3Largeにあるような学習済みのモデルを利用してみます。
TensorFlow (Keras)のモデルの保存形式にはKerasとTF(SavedModel)2種類がありますが、今回利用するのはこのどちらでもありません。一応と言ってはなんですが、TF(SavedModel)のフォーマットもC-APIから直接利用できるのですが、ちょっと手間がかかるので次回以降で紹介します。
大まかな手順としては
- Kerasのモデルのインスタンスを作成する
-
tf.function
でラップして具象関数(Concrete function)を作成する - モデルを固定する(frozen modelに変換する)
- 計算グラフを保存する
となります。
1. Kerasのモデルのインスタンスを作成する。
ここでは、例としてMovileNetV3Largeを使っていますが、任意のKerasのモデルに置き換えても大丈夫です。
import tensorflow as tf
# default input shape 224x224x3
model = tf.keras.applications.MobileNetV3Large()
tf.function
でラップして具象関数(Concrete function)を作成する
2. tf.function
はTensorFlow 2で新たに導入されたデコレータであり、実行の効率と処理の柔軟性を両立させる仕組みとして機能しています。Pythonから使う時にはとても便利になったと思うのですが、TensorFlow特有の計算グラフやSession周りがすべて隠蔽されてしまっています。
そこで、具体的な計算グラフを取得する方法として、.get_concrete_function()
というメソッドを使います。この引数に渡しているx = tf.TensorSpec(model.input_shape, tf.float32, name="x")
はTensorFlow 1.xで見られたplaceholderに相当します。
x = tf.TensorSpec(model.input_shape, tf.float32, name="x")
concrete_function = tf.function(lambda x: model(x)).get_concrete_function(x)
この関数を呼び出せることを確認しておきます。この場合、concrete_function
の使い方は元のKerasのモデルと同じです。
buf = tf.io.read_file("examples/zenn/sample.png")
img = tf.image.decode_png(buf)
sample = tf.cast(img[tf.newaxis, :, :, :], tf.float32)
pred = concrete_function(sample)
エラーがでなければひとまず大丈夫でしょう。
3. モデルを固定する(frozen modelに変換する)
先ほどのconcrete_function
は、計算グラフの内部に変数(tf.Variable
)が残っているため、Sessionが終了すると変数の情報が失われてしまいます。そこで、モデル内部の変数をすべて定数に置き換えていわゆるfrozen graph
を作成します。
from tensorflow.python.framework.convert_to_constants import (
convert_variables_to_constants_v2,
)
# now all variables are converted to constants.
# if this step is omitted, dumped graph does not include trained weights
frozen_model = convert_variables_to_constants_v2(concrete_function)
print(f"{frozen_model.inputs=}")
print(f"{frozen_model.outputs=}")
inputとoutputのノード情報は後でRustから呼び出すときに使います。このモデルでは、入力が"x"
という名前の4次元のテンソルで、出力が"Identity"
という名前の2次元のテンソルであることが分かります。
frozen_model.inputs=[<tf.Tensor 'x:0' shape=(None, None, None, 3) dtype=float32>]
frozen_model.outputs=[<tf.Tensor 'Identity:0' shape=(None, 1000) dtype=float32>]
4. 計算グラフを保存する
最後にfrozen_model
の計算グラフをProtocol buffer形式で書き出します。
directory = "examples/zenn"
tf.io.write_graph(frozen_model.graph, directory, "mobilenetv3large.pb", as_text=False)
以上でモデルの準備はおしまいです。最初のモデル定義とTensorSpecのところを適宜書き換えることで、他のモデルでも同様の手順で書き出せます。
Rustから呼び出す
ここまで準備ができれば、あとはtensorflow/rust
のexapmplesを見ながら進められると思います。やることは、
- 入力テンソルを作成する
- グラフを読み込む
- グラフを実行する
1. 入力テンソルを作成する
このモデルに食わせる入力データを準備しておきます。画像は224x224に予めクリッピングされてあるものを用いています。また、MobileNetV3は入力は0-255のfloat型なので、それに合わせてやります。なお、Tensorの型はRustの型推論でf32に自動で決めてくれますが、自分で指定したい場合は<Tensor<f32>>::new()
もしくはTensor::<f32>::new()
と書くこともできます。
画像の読み込みはImage crateを使っています。また、テンソルへのアクセスはVecへのアクセスのようですが、Column-major orderの4次元テンソルであることに注意してください。
// Create input variables for our addition
let mut x = Tensor::new(&[1, 224, 224, 3]);
let img = ImageReader::open("examples/zenn/sample.png")?.decode()?;
for (i, (_, _, pixel)) in img.pixels().enumerate() {
x[3 * i] = pixel.0[0] as f32;
x[3 * i + 1] = pixel.0[1] as f32;
x[3 * i + 2] = pixel.0[2] as f32;
}
2. グラフを読み込む
// Load the computation graph defined by zenn.py.
let mut graph = Graph::new();
let mut proto = Vec::new();
File::open(filename)?.read_to_end(&mut proto)?;
graph.import_graph_def(&proto, &ImportGraphDefOptions::new())?;
let session = Session::new(&SessionOptions::new(), &graph)?;
3. グラフを実行する
このセッションを使って、グラフを実行します。
先ほどのPythonで取得した情報を使います。入力ノード"x"
に入力テンソルxを渡して、出力ノード"Identity"
から計算結果を取得します。
ここで、inputとoutputのノード情報は後でRustから呼び出すときに使います。このモデルでは、入力が
"x"
という名前の4次元のテンソルで、出力が"Identity"
という名前の2次元のテンソルであることが分かります。
// Run the graph.
let mut args = SessionRunArgs::new();
args.add_feed(&graph.operation_by_name_required("x")?, 0, &x);
let output = args.request_fetch(&graph.operation_by_name_required("Identity")?, 0);
session.run(&mut args)?;
4. 結果を確認する
"Identity"
から取得したテンソルを回収します。Vecで受けてもいいのですが、ndarrayへの変換をオプションでサポートしているので、それを使います。
// Check our results.
let output: Tensor<f32> = args.fetch(output)?;
let res: ndarray::Array<f32, _> = output.into();
println!("{:?}", res);
何やら数値が返ってきました。Tensorはndarrayに変換する関数を適宜してくれてあり、ただのvecで受け取るよりも使い勝手がいいように思います。
使用したコード全体
コード全文を下記に掲載します。また、dskkato/tf2_python_rustにもコードを掲載しています。
Python側
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
convert_variables_to_constants_v2,
)
# default input shape 224x224x3
model = tf.keras.applications.MobileNetV3Large()
x = tf.TensorSpec(model.input_shape, tf.float32, name="x")
concrete_function = tf.function(lambda x: model(x)).get_concrete_function(x)
# now all variables are converted to constants.
# if this step is omitted, dumped graph does not include trained weights
frozen_model = convert_variables_to_constants_v2(concrete_function)
directory = "examples/zenn"
tf.io.write_graph(frozen_model.graph, directory, "mobilenetv3large.pb", as_text=False)
Rust側
[package]
name = "tf2_python_rust"
version = "0.1.0"
edition = "2018"
[dependencies]
image = "0.23.14"
tensorflow = {version="0.17.0", features=["ndarray"]}
use std::error::Error;
use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::result::Result;
use tensorflow::Code;
use tensorflow::Graph;
use tensorflow::ImportGraphDefOptions;
use tensorflow::Session;
use tensorflow::SessionOptions;
use tensorflow::SessionRunArgs;
use tensorflow::Status;
use tensorflow::Tensor;
use ndarray;
use image::io::Reader as ImageReader;
use image::GenericImageView;
fn main() -> Result<(), Box<dyn Error>> {
let filename = "examples/zenn/mobilenetv3large.pb";
if !Path::new(filename).exists() {
return Err(Box::new(
Status::new_set(
Code::NotFound,
&format!(
"Run 'python examples/zenn/zenn.py' to generate {} \
and try again.",
filename
),
)
.unwrap(),
));
}
// Create input variables for our addition
let mut x = Tensor::new(&[1, 224, 224, 3]);
let img = ImageReader::open("examples/zenn/sample.png")?.decode()?;
for (i, (_, _, pixel)) in img.pixels().enumerate() {
x[3 * i] = pixel.0[0] as f32;
x[3 * i + 1] = pixel.0[1] as f32;
x[3 * i + 2] = pixel.0[2] as f32;
}
// Load the computation graph defined by addition.py.
let mut graph = Graph::new();
let mut proto = Vec::new();
File::open(filename)?.read_to_end(&mut proto)?;
graph.import_graph_def(&proto, &ImportGraphDefOptions::new())?;
let session = Session::new(&SessionOptions::new(), &graph)?;
// Run the graph.
let mut args = SessionRunArgs::new();
args.add_feed(&graph.operation_by_name_required("x")?, 0, &x);
let output = args.request_fetch(&graph.operation_by_name_required("Identity")?, 0);
session.run(&mut args)?;
// Check our results.
let output: Tensor<f32> = args.fetch(output)?;
let res: ndarray::Array<f32, _> = output.into();
println!("{:?}", res);
Ok(())
}
後書き
frozen_graphの形式は、C-APIから利用する分には便利なのですが、学習しなおしたりモデルを更新できません。その点を見るとSavedModelというフォーマットのほうが便利なのですが、C-APIから利用するための準備はここで説明したものよりも煩雑になります。こちらはまた後日書き加えようと思います。
また、この方法をmasOS用のTensorFlowをtensorflow-metalのプラグインを使った環境でシリアライズするとうまくエクスポートできませんでした。この場合、SavedModelのフォーマットのまま利用することをおススメします。
参考
-
https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/
- ここで説明した内容と基本的に同じです
-
https://qiita.com/karaage0703/items/5946e41b6043795c1b30
- 変換方法は違いますが、その他の
.pb
の中身を調べるのには参考になりそうです。
- 変換方法は違いますが、その他の
Discussion