
Converting frozen_inference_graph.pb to a SavedModel with TensorFlow v2

Published 2023/08/31

These are personal notes, so they are far from perfect. If you know a better way, please let me know in the comments.

References
https://saturncloud.io/blog/how-to-convert-a-tensorflow-frozen-graph-to-savedmodel/

https://stackoverflow.com/questions/44329185/convert-a-graph-proto-pb-pbtxt-to-a-savedmodel-for-use-in-tensorflow-serving-o/44329200#44329200

```python
import tensorflow as tf
print(tf.__version__)  # 2.13.0

# Path to the frozen graph file
frozen_graph_path = 'path_to_frozen_inference_graph.pb'

# Load the frozen graph (a serialized GraphDef protobuf)
with tf.io.gfile.GFile(frozen_graph_path, "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

# Import the GraphDef into a new tf.Graph; name="" keeps the original node names (no "import/" prefix)
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")
```

This loads the frozen graph. Change frozen_graph_path to point to your own file.

```python
# Path to the SavedModel output directory
# (SavedModelBuilder raises an error if this directory already exists and is not empty)
saved_model_path = 'saved_model_dir_name'

# The TF1-style builder works with graphs and sessions, so disable eager execution
if tf.executing_eagerly():
    tf.compat.v1.disable_eager_execution()

# Create the SavedModel builder
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(saved_model_path)

# Define input and output signatures.
# The tensor names must match the ones in your frozen graph
# (the names below are those used by TF Object Detection API exports).
inputs = {
    "image_tensor": tf.compat.v1.saved_model.utils.build_tensor_info(graph.get_tensor_by_name("image_tensor:0"))
}

outputs = {
    "num_detections": tf.compat.v1.saved_model.utils.build_tensor_info(graph.get_tensor_by_name("num_detections:0")),
    "detection_boxes": tf.compat.v1.saved_model.utils.build_tensor_info(graph.get_tensor_by_name("detection_boxes:0")),
    "detection_scores": tf.compat.v1.saved_model.utils.build_tensor_info(graph.get_tensor_by_name("detection_scores:0")),
    "detection_classes": tf.compat.v1.saved_model.utils.build_tensor_info(graph.get_tensor_by_name("detection_classes:0")),
}

# Define the signature definition
signature_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
    inputs=inputs,
    outputs=outputs,
    method_name=tf.compat.v1.saved_model.signature_constants.PREDICT_METHOD_NAME
)

# Add the graph to the SavedModel.
# The session must be bound to `graph`: entering it also makes `graph` the default graph,
# which is the graph the builder exports. A frozen graph has no variables, so only the graph is stored.
with tf.compat.v1.Session(graph=graph) as sess:
    builder.add_meta_graph_and_variables(
        sess=sess,
        tags=[tf.compat.v1.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
        }
    )

# Write the SavedModel to disk
builder.save()
```

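This sets the directory the SavedModel is written to and defines the inputs and outputs.
I don't know a good way to look up the input and output names of a frozen graph, so if you know one, I'd appreciate a comment.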
