📘
tensorflow v2でfrozen_inference_graph.pbをsaved_modelに変換
覚書なので、全くもって完ぺきではありません。なのでより良い方法がありましたら、コメントをお願いします。
参考
import tensorflow as tf
print(tf.__version__) #2.13.0
# Path to frozen graph file
frozen_graph_path = 'path_to_frozen_inference_graph.pb'
# Load frozen graph
with tf.io.gfile.GFile(frozen_graph_path, "rb") as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
# Create graph from graph_def
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name="")
ここでfrozen_graphを読み込む。frozen_graph_pathを適宜書き直す。
# Path to SavedModel directory
saved_model_path = 'saved_model_dir_name'
# Create SavedModel builder
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(saved_model_path)
if tf.executing_eagerly():
tf.compat.v1.disable_eager_execution()
# Define input and output signatures
# inputs,outputsはfrozen_graphの名前に合わせる。
inputs = {
"image_tensor": tf.compat.v1.saved_model.utils.build_tensor_info(graph.get_tensor_by_name("image_tensor:0"))
}
outputs = {
"num_detections": tf.compat.v1.saved_model.utils.build_tensor_info(graph.get_tensor_by_name("num_detections:0")),
"detection_boxes": tf.compat.v1.saved_model.utils.build_tensor_info(graph.get_tensor_by_name("detection_boxes:0")),
"detection_scores": tf.compat.v1.saved_model.utils.build_tensor_info(graph.get_tensor_by_name("detection_scores:0")),
"detection_classes": tf.compat.v1.saved_model.utils.build_tensor_info(graph.get_tensor_by_name("detection_classes:0")),
}
# Define signature definition
signature_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs=inputs,
outputs=outputs,
method_name=tf.compat.v1.saved_model.signature_constants.PREDICT_METHOD_NAME
)
# Add graph to SavedModel
builder.add_meta_graph_and_variables(
sess=tf.compat.v1.Session(),
tags=[tf.compat.v1.saved_model.tag_constants.SERVING],
signature_def_map={
tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
}
)
# Save SavedModel
builder.save()
ここでSavedModelを保存するDirとinputs,outputsを定義する。
frozen_graphのinputs,outputsを確認するいい方法は知らないので、知っている方はコメント頂けると幸いです。
Discussion