Open6
Export of the NVIDIA-AI-IOT stereoDNN model
- TensorFlow v2.7.0
- model_builder.py
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Full license terms provided in LICENSE.md file.
"""
Generates TensorRT C++ API code from TensorFlow model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import warnings
# Ignore 'FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated' warning.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning, module='h5py')
    # TF1-style API on a TF2 install (header says v2.7.0); eager execution is
    # disabled because this script drives an explicit Session/graph.
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()
import tensorrt_model_builder
import model_nvsmall
import model_resnet18
import model_resnet18_2D
import os
import sys
def check_model_type(src):
    """argparse type-checker for --model_type.

    Returns src unchanged when it names a supported model; raises
    argparse.ArgumentTypeError otherwise so argparse reports a clean
    usage error.
    """
    supported_types = ['nvsmall', 'resnet18', 'resnet18_2D']
    if src not in supported_types:
        raise argparse.ArgumentTypeError(
            'Invalid model type {}. Supported: {}'.format(src, ', '.join(supported_types)))
    return src
def check_data_type(src):
    """argparse type-checker for --data_type.

    Accepts only 'fp32' or 'fp16'; any other value raises
    argparse.ArgumentTypeError so argparse reports a clean usage error.
    """
    if src in ('fp32', 'fp16'):
        return src
    raise argparse.ArgumentTypeError('Invalid data type {}. Supported: fp32, fp16'.format(src))
# Command-line interface. In this modified copy only --checkpoint_path is
# active: read_model() below freezes the graph and exits the process, so the
# TensorRT code-generation options are commented out. Re-enable them (and the
# check_* validators) before restoring the original code-generation flow.
parser = argparse.ArgumentParser(description='Stereo DNN TensorRT C++ code generator')
# parser.add_argument('--model_type', type=check_model_type, help='model type, currently supported: nvsmall', required=True)
# parser.add_argument('--net_name', type=str, help='network name to use in C++ code generation', required=True)
parser.add_argument('--checkpoint_path', type=str, help='path to checkpoint file (without extension)', required=True)
# parser.add_argument('--weights_file', type=str, help='path to generated weights file', required=True)
# parser.add_argument('--cpp_file', type=str, help='path to generated TensorRT C++ model file', required=True)
# parser.add_argument('--data_type', type=check_data_type, help='model data type, supported: fp32, fp16', default='fp32')
args = parser.parse_args()
def read_model(model_path, session):
    """Restore a TF checkpoint, freeze the graph to <checkpoint-name>.pb, exit.

    model_path: checkpoint path WITHOUT extension; '.meta' is appended to load
        the graph definition, and the weights are restored from the same prefix.
    session: live tf.compat.v1 session the graph/weights are restored into.

    NOTE(review): this modified copy never returns -- after writing the frozen
    .pb it calls sys.exit(0), deliberately short-circuiting the TensorRT code
    generation in main().
    """
    print('Reading model...')
    # clear_devices=True strips device placements recorded in the checkpoint.
    saver = tf.train.import_meta_graph(model_path + '.meta', clear_devices=True)
    print('Loaded graph meta.')
    saver.restore(session, model_path)
    print('Loaded weights.')
    print('Done reading model.')
    basename_without_ext = os.path.splitext(os.path.basename(model_path))[0]
    # Debug dump of the still variable-bearing graph definition. The file name
    # is the hard-coded 'test.pb' (the f-string has no placeholders).
    tf.train.write_graph(session.graph_def, '.', f'test.pb', as_text=False)
    from tensorflow.python.framework import graph_io
    from tensorflow.python.framework import graph_util
    def freeze_graph(session, output, save_pb_dir='.', save_pb_name='frozen_model.pb', save_pb_as_text=False):
        # Strip training-only nodes, fold variables into constants reachable
        # from the given output node names, then serialize the frozen graph.
        graph = session.graph
        with graph.as_default():
            graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
            graphdef_frozen = graph_util.convert_variables_to_constants(session, graphdef_inf, output)
            graph_io.write_graph(graphdef_frozen, save_pb_dir, save_pb_name, as_text=save_pb_as_text)
        return graphdef_frozen
    # Output node names -- presumably the two disparity heads of the stereo
    # DNN; TODO confirm against the node names in the actual checkpoint graph.
    freeze_graph(
        session,
        [
            'disparities/ExpandDims',
            'disparities/ExpandDims_1'
        ],
        save_pb_dir='.',
        save_pb_name=f'{basename_without_ext}.pb'
    )
    # Stop here: only the frozen-graph export is wanted from this script.
    sys.exit(0)
def main():
    """Entry point: open a TF session, restore the checkpoint, and (in this
    modified copy) freeze the graph via read_model(), which exits the process
    before the TensorRT code-generation code below can run."""
    config = tf.ConfigProto(allow_soft_placement=True)
    # Grow GPU memory on demand instead of pre-allocating all of it.
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)
    # read_model() calls sys.exit(0) after writing the frozen .pb, so
    # everything below is currently unreachable.
    model = read_model(args.checkpoint_path, sess)
    # NOTE(review): the parser options for weights_file / cpp_file / net_name /
    # data_type / model_type are commented out above; if read_model() ever
    # returns again, these attribute accesses would raise AttributeError.
    # Re-enable the options before restoring this path.
    with open(args.weights_file, 'wb') as weights_w:
        with open(args.cpp_file, 'w') as cpp_w:
            builder = tensorrt_model_builder.TrtModelBuilder(model, args.net_name, cpp_w, weights_w, args.data_type)
            if args.model_type == 'nvsmall':
                model_nvsmall.create(builder)
            elif args.model_type == 'resnet18':
                model_resnet18.create(builder)
            elif args.model_type == 'resnet18_2D':
                model_resnet18_2D.create(builder)
            else:
                # Should never happen, yeah.
                assert False, 'Not supported.'
    print('Done.')
if __name__ == '__main__':
    main()