Open6

NVIDIA-AI-IOT stereoDNN のエクスポート

  • TensorFlow v2.7.0
  • model_builder.py
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Full license terms provided in LICENSE.md file.

"""
Generates TensorRT C++ API code from TensorFlow model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

import warnings
# Ignore 'FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated' warning.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning, module='h5py')
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

import tensorrt_model_builder
import model_nvsmall
import model_resnet18
import model_resnet18_2D

import os
import sys

def check_model_type(src):
    """Argparse ``type=`` validator for the model type option.

    Returns ``src`` unchanged when it names a supported model; otherwise
    raises ``argparse.ArgumentTypeError`` listing the accepted values.
    """
    known_models = ['nvsmall', 'resnet18', 'resnet18_2D']
    if src not in known_models:
        raise argparse.ArgumentTypeError(
            'Invalid model type {}. Supported: {}'.format(src, ', '.join(known_models)))
    return src

def check_data_type(src):
    """Argparse ``type=`` validator for the data type option.

    Accepts only ``'fp32'`` or ``'fp16'`` and returns it unchanged; any other
    value raises ``argparse.ArgumentTypeError``.
    """
    if src not in ('fp32', 'fp16'):
        raise argparse.ArgumentTypeError(
            'Invalid data type {}. Supported: fp32, fp16'.format(src))
    return src

parser = argparse.ArgumentParser(description='Stereo DNN TensorRT C++ code generator')

# Only --checkpoint_path is mandatory: on its own it is enough for the
# frozen-graph export performed by read_model(). The other options feed the
# TensorRT C++ code generation in main(). They are declared as optional
# (instead of being commented out) so that `args` always carries every
# attribute main() reads, while an invocation with just --checkpoint_path
# keeps working.
parser.add_argument('--model_type', type=check_model_type, default=None,
                    help='model type, supported: nvsmall, resnet18, resnet18_2D')
parser.add_argument('--net_name', type=str, default=None,
                    help='network name to use in C++ code generation')
parser.add_argument('--checkpoint_path', type=str, required=True,
                    help='path to checkpoint file (without extension)')
parser.add_argument('--weights_file', type=str, default=None,
                    help='path to generated weights file')
parser.add_argument('--cpp_file', type=str, default=None,
                    help='path to generated TensorRT C++ model file')
parser.add_argument('--data_type', type=check_data_type, default='fp32',
                    help='model data type, supported: fp32, fp16')

args = parser.parse_args()

def read_model(model_path, session):
    """Restore a checkpoint into ``session`` and export it as GraphDef files.

    Args:
        model_path: checkpoint prefix without extension. The graph is
            imported from ``model_path + '.meta'`` and the variables are
            restored from ``model_path``.
        session: TF1-style session that receives the restored graph/weights.

    Side effects (files are written to the current working directory):
        * ``test.pb``: binary GraphDef of the session graph before freezing.
        * ``<basename of model_path without extension>.pb``: binary GraphDef
          frozen for the ``disparities/ExpandDims`` and
          ``disparities/ExpandDims_1`` output nodes.

    Never returns normally: the process ends with ``sys.exit(0)`` once the
    frozen graph has been written.
    NOTE(review): main() assigns this call's result to ``model`` and then
    proceeds to code generation, which is unreachable while this exit stays.
    """
    print('Reading model...')
    meta_file = model_path + '.meta'
    # clear_devices=True drops device placements stored in the meta graph.
    restorer = tf.train.import_meta_graph(meta_file, clear_devices=True)
    print('Loaded graph meta.')
    restorer.restore(session, model_path)
    print('Loaded weights.')
    print('Done reading model.')

    stem = os.path.splitext(os.path.basename(model_path))[0]
    # Fixed-name dump of the graph before freezing (looks like a debug
    # leftover; kept as-is).
    tf.train.write_graph(session.graph_def, '.', 'test.pb', as_text=False)

    from tensorflow.python.framework import graph_io
    from tensorflow.python.framework import graph_util

    # Nodes that the frozen graph is built for.
    output_nodes = ['disparities/ExpandDims', 'disparities/ExpandDims_1']
    graph = session.graph
    with graph.as_default():
        inference_def = tf.graph_util.remove_training_nodes(graph.as_graph_def())
        # Replace variables with constants holding their current values.
        frozen_def = graph_util.convert_variables_to_constants(
            session, inference_def, output_nodes)
        graph_io.write_graph(frozen_def, '.', '{}.pb'.format(stem), as_text=False)
    sys.exit(0)


def main():
    """Load the checkpoint and emit TensorRT C++ code plus a weights file.

    Restores the checkpoint named by ``--checkpoint_path`` into an
    InteractiveSession. It then drives the model-specific ``create``
    function with a ``TrtModelBuilder`` that writes ``--cpp_file`` and
    ``--weights_file``.

    NOTE(review): ``read_model`` currently calls ``sys.exit(0)`` after
    exporting a frozen graph, so the code-generation steps below are
    unreachable until that early exit is removed.
    """
    config = tf.ConfigProto(allow_soft_placement=True)
    # Grow GPU memory on demand instead of reserving it all at start-up.
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)
    try:
        model = read_model(args.checkpoint_path, sess)

        # The code-generation options may be absent from the parsed namespace
        # (e.g. when their add_argument() calls are disabled), so read them
        # defensively and report a usage error instead of an AttributeError.
        model_type = getattr(args, 'model_type', None)
        net_name = getattr(args, 'net_name', None)
        weights_file = getattr(args, 'weights_file', None)
        cpp_file = getattr(args, 'cpp_file', None)
        data_type = getattr(args, 'data_type', None) or 'fp32'
        missing = [
            flag
            for flag, value in (
                ('--model_type', model_type),
                ('--net_name', net_name),
                ('--weights_file', weights_file),
                ('--cpp_file', cpp_file),
            )
            if value is None
        ]
        if missing:
            parser.error('code generation requires: {}'.format(', '.join(missing)))

        creators = {
            'nvsmall': model_nvsmall.create,
            'resnet18': model_resnet18.create,
            'resnet18_2D': model_resnet18_2D.create,
        }
        create = creators.get(model_type)
        if create is None:
            # check_model_type should already reject this at parse time.
            # Raise (not assert) so the check survives `python -O`, and do it
            # before the output files are opened/truncated.
            raise ValueError('Not supported model type: {}'.format(model_type))

        with open(weights_file, 'wb') as weights_w, open(cpp_file, 'w') as cpp_w:
            builder = tensorrt_model_builder.TrtModelBuilder(
                model, net_name, cpp_w, weights_w, data_type)
            create(builder)
    finally:
        # Close the session even if reading fails or read_model() exits.
        sess.close()
    print('Done.')

if __name__ == '__main__':
    main()
ログインするとコメントできます