📝

TensorFlow 2.xでのRustとPython (SavedModel編)

2021/07/15に公開

前回は計算グラフをprobobuf(.pb)形式でエクスポートしてRustから利用する方法を説明しました。この方法はモデルファイルが比較的コンパクトになるのがメリットですが、Pythonから再利用する時や、TFLiteへの変換ができないなど取り回しが悪くなってしまいます。そこで、TensorFlowの標準のフォーマットであるSavedModelの形式をRustから利用する方法を説明します。

SavedModelについてはこちらを参考にしてください。

SavedModelの中にも計算グラフが入っているので、基本的にはグラフを直接読み込みに行く方法と大きな違いはありません。

KerasのモデルをSavedModel形式で保存する。@Python側

下記の3つの作業を行います。1, 2はPythonコードで3はsaved_model_cliというコマンドラインツールを使います。このツールはTensorFlowをインストールすると付いてきます。

  1. Kerasのモデルのインスタンスを作成する。
  2. SavedModel形式で保存する
  3. 書き出したモデルの情報を集める。

1. Kerasのモデルのインスタンスを作成する。

前回と同じものを使いまわしします。

import tensorflow as tf

# default input shape 224x224x3
model = tf.keras.applications.MobileNetV3Large()

2. SavedModel形式で保存する

SavedModel形式で保存するAPIは2つあります。

  • tf.saved_model.save
  • tf.keras.Modelのメソッド

の2つがあります。どちらを使ってもいいのですが、今回ロードしたモデルはKerasのモデルなので後者のModelのAPIをそのまま使う方法で進めてみます。

directory = "examples/zenn_savedmodel"
model.save(directory)

3. 書き出したモデルの情報を集める。

saved_model_cliを使って書き出したモデルファイルの情報を確認しておきます。

SavedModelの中の次の3つを調べます。調べたいのは3つ目のin/outのテンソルの情報です。

  • 3.1 tag
    • デフォルト=>"serve"
  • 3.2 SignatureDefs
    • デフォルト=>"serving_default"
  • 3.3 inputs/outputs tensor_info
    • モデルに依存

3.1 tag情報

saved_model_cli show --dir examples/zenn_savedmodel
...
The given SavedModel contains the following tag-sets:
'serve'

serveというtagがあることが分かります。

3.2 SignatureDefs

saved_model_cli show --dir examples/zenn_savedmodel --tag serve
...
The given SavedModel MetaGraphDef contains SignatureDefs with the following keys:
SignatureDef key: "__saved_model_init_op"
SignatureDef key: "serving_default"

SignatureDefが2つ見えますが、今回使うのは"serving_default"のほうだけです。"serving_default"は保存時に"signature_def"を与えなかった場合に設定されるデフォルト値です。

3.3 inputs/outputs tensor_info

最後に先ほどの"serving_default"のin/outの情報を確認します。

saved_model_cli show --dir examples/zenn_savedmodel --tag serve --signature_def serving_default
...
The given SavedModel SignatureDef contains the following input(s):
  inputs['input_1'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, -1, -1, 3)
      name: serving_default_input_1:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['Predictions'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 1000)
      name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict

ここから、この計算グラフの入力が"input_1"が1つで、出力が"Predictions"が1つであることが分かります。これらの情報を後で使います。

Rust側

1. 入力テンソルを作成する

入力テンソルを作ります。ここは前回と同様です。

// Create input variables for our addition
let mut x = Tensor::new(&[1, 224, 224, 3]);
let img = ImageReader::open("examples/zenn/sample.png")?.decode()?;
for (i, (_, _, pixel)) in img.pixels().enumerate() {
    x[3 * i] = pixel.0[0] as f32;
    x[3 * i + 1] = pixel.0[1] as f32;
    x[3 * i + 2] = pixel.0[2] as f32;
}

2. モデルを読み込む

前回はimport_graph_defという関数を使って計算グラフを読み込みました。

graph.import_graph_def(&proto, &ImportGraphDefOptions::new())?;`

SavedModelフォーマットの場合、次のようにSavedModelBundle::loadを使います。

先ほど調べた次の情報を使っています。

  • tag
    • serve
  • SignatureDegs
    • "serving_default"
    • ここではDEFAULT_SERVING_SIGNATURE_DEF_KEYという定数を使っています。
  • "serving_default"のin/out
    • 入力は"input_1"
    • 出力は"Predictions"
// Load the saved model exported by zenn_savedmodel.py.
let mut graph = Graph::new();
let bundle =
    SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?;
let session = &bundle.session;

// get in/out operations
let signature = bundle
    .meta_graph_def()
    .get_signature(DEFAULT_SERVING_SIGNATURE_DEF_KEY)?;
let x_info = signature.get_input("input_1")?;
let op_x = &graph.operation_by_name_required(&x_info.name().name)?;
let output_info = signature.get_output("Predictions")?;
let op_output = &graph.operation_by_name_required(&output_info.name().name)?;

3. グラフを実行する

SavedModelBundle::loadで初期化したSession、graphを使って計算を実行します。

// Run the graph.
let mut args = SessionRunArgs::new();
args.add_feed(op_x, 0, &x);
let token_output = args.request_fetch(op_output, 0);
session.run(&mut args)?;

4. 結果を確認する

前回と同じですので、説明は特にありません。

// Check our results.
let output: Tensor<f32> = args.fetch(token_output)?;
let res: ndarray::Array<f32, _> = output.into();
println!("{:?}", res);

使用したコード全体

コード全文を下記に掲載します。また、dskkato/tf2_python_rustにもコードを掲載しています。

Python側

examples/zenn_savedmodel/zenn_savedmodel.py
import tensorflow as tf

# default input shape 224x224x3
model = tf.keras.applications.MobileNetV3Large()

directory = "examples/zenn_savedmodel"
model.save(directory)

Rust側

Cargo.toml
[package]
name = "tf2_python_rust"
version = "0.1.0"
edition = "2018"

[dependencies]
image = "0.23.14"
tensorflow = {version="0.17.0", features=["ndarray"]}
examples/zenn_savedmodel.rs
use std::error::Error;
use std::path::Path;
use std::result::Result;
use tensorflow::Code;
use tensorflow::Graph;
use tensorflow::SavedModelBundle;
use tensorflow::SessionOptions;
use tensorflow::SessionRunArgs;
use tensorflow::Status;
use tensorflow::Tensor;
use tensorflow::DEFAULT_SERVING_SIGNATURE_DEF_KEY;

use ndarray;

use image::io::Reader as ImageReader;
use image::GenericImageView;

fn main() -> Result<(), Box<dyn Error>> {
    let export_dir = "examples/zenn_savedmodel";
    if !Path::new(export_dir).exists() {
        return Err(Box::new(
            Status::new_set(
                Code::NotFound,
                &format!(
                    "Run 'python zenn_savedmodel.py' to generate \
                     {} and try again.",
                    export_dir
                ),
            )
            .unwrap(),
        ));
    }

    // Create input variables for our addition
    let mut x = Tensor::new(&[1, 224, 224, 3]);
    let img = ImageReader::open("examples/zenn/sample.png")?.decode()?;
    for (i, (_, _, pixel)) in img.pixels().enumerate() {
        x[3 * i] = pixel.0[0] as f32;
        x[3 * i + 1] = pixel.0[1] as f32;
        x[3 * i + 2] = pixel.0[2] as f32;
    }

    // Load the saved model exported by zenn_savedmodel.py.
    let mut graph = Graph::new();
    let bundle =
        SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?;
    let session = &bundle.session;

    let signature = bundle
        .meta_graph_def()
        .get_signature(DEFAULT_SERVING_SIGNATURE_DEF_KEY)?;
    let x_info = signature.get_input("input_1")?;
    let op_x = &graph.operation_by_name_required(&x_info.name().name)?;
    let output_info = signature.get_output("Predictions")?;
    let op_output = &graph.operation_by_name_required(&output_info.name().name)?;

    // Run the graph.
    let mut args = SessionRunArgs::new();
    args.add_feed(op_x, 0, &x);
    let token_output = args.request_fetch(op_output, 0);
    session.run(&mut args)?;

    // Check our results.
    let output: Tensor<f32> = args.fetch(token_output)?;
    let res: ndarray::Array<f32, _> = output.into();
    println!("{:?}", res);

    Ok(())
}

参考

Discussion