Open2

ONNX TensorRT の Float16 cache (キャッシュ) を生成するだけのコードスニペット

PINTOPINTO
#! /usr/bin/env python

import onnxruntime
import numpy as np
from argparse import ArgumentParser


class Color:
    """ANSI escape sequences for coloring and styling terminal output.

    Prefix the text with one or more of these constants and append
    ``RESET`` afterwards, e.g. ``f'{Color.GREEN}INFO:{Color.RESET} ...'``.
    """

    # Foreground colors
    BLACK          = '\033[30m'
    RED            = '\033[31m'
    GREEN          = '\033[32m'
    YELLOW         = '\033[33m'
    BLUE           = '\033[34m'
    MAGENTA        = '\033[35m'
    CYAN           = '\033[36m'
    WHITE          = '\033[37m'
    COLOR_DEFAULT  = '\033[39m'

    # Text attributes
    BOLD           = '\033[1m'
    UNDERLINE      = '\033[4m'
    INVISIBLE      = '\033[08m'
    REVERSE        = '\033[07m'
    # Misspelled legacy name, kept as an alias so existing references keep working.
    REVERCE        = REVERSE

    # Background colors
    BG_BLACK       = '\033[40m'
    BG_RED         = '\033[41m'
    BG_GREEN       = '\033[42m'
    BG_YELLOW      = '\033[43m'
    BG_BLUE        = '\033[44m'
    BG_MAGENTA     = '\033[45m'
    BG_CYAN        = '\033[46m'
    BG_WHITE       = '\033[47m'
    BG_DEFAULT     = '\033[49m'

    # Clears every color/attribute set above
    RESET          = '\033[0m'


def main():
    """Run an ONNX model once through the TensorRT provider with FP16 caching.

    Parses the command line, opens the model with ONNX Runtime using the
    TensorRT execution provider (engine cache and FP16 both enabled, cache
    path set to the requested folder), and performs a single inference on
    all-ones dummy inputs. That single run is what is expected to make the
    provider build its engine and store it in the cache folder.

    Command-line arguments:
        -i / --input_onnx_file_path: path of the ONNX model (required).
        -s / --input_shapes: fixed shape for one model input, e.g.
            ``-s 1 3 224 224``. Repeat the option once per model input, in
            the model's input order. When omitted, the shapes declared by
            the model are used and every non-numeric (dynamic) dimension
            is replaced by 1.
        -o / --output_cache_folder_path: folder handed to the provider as
            its engine cache path (required). Created if missing.

    Raises:
        SystemExit: from argparse on invalid arguments, including fewer
            ``-s`` groups than the model has inputs.
        KeyError: if a model input uses an element type that has no entry
            in the ONNX-to-NumPy dtype table.
    """
    import os  # local import: only this entry point needs it

    parser = ArgumentParser()
    parser.add_argument(
        '-i',
        '--input_onnx_file_path',
        type=str,
        required=True,
        help='INPUT ONNX file path',
    )
    parser.add_argument(
        '-s',
        '--input_shapes',
        type=int,
        nargs='+',
        action='append',
        default=None,
        help='INPUT fixed shape. e.g.: 1 3 224 224',
    )
    parser.add_argument(
        '-o',
        '--output_cache_folder_path',
        type=str,
        required=True,
        help='OUTPUT cache folder path',
    )
    args = parser.parse_args()

    # Maps the type string reported for each model input to a NumPy dtype.
    ONNX_DTYPES_TO_NUMPY_TYPES = {
        'tensor(float16)': np.float16,
        'tensor(float)': np.float32,
        'tensor(double)': np.float64,
        'tensor(int8)': np.int8,
        'tensor(int16)': np.int16,
        'tensor(int32)': np.int32,
        'tensor(int64)': np.int64,
        'tensor(uint8)': np.uint8,
        'tensor(uint16)': np.uint16,
        'tensor(uint32)': np.uint32,
        'tensor(uint64)': np.uint64,
        'tensor(bool)': np.bool_,
    }

    # Create the cache folder (including missing parents) before the
    # provider is configured with it.
    if args.output_cache_folder_path:
        os.makedirs(args.output_cache_folder_path, exist_ok=True)

    session_option = onnxruntime.SessionOptions()
    session_option.log_severity_level = 3
    onnx_session = onnxruntime.InferenceSession(
        path_or_bytes=args.input_onnx_file_path,
        sess_options=session_option,
        providers=[
            (
                'TensorrtExecutionProvider', {
                    'trt_engine_cache_enable': True,
                    'trt_engine_cache_path': args.output_cache_folder_path,
                    'trt_fp16_enable': True,
                }
            ),
        ],
    )
    providers = onnx_session.get_providers()
    print(f'{Color.GREEN}INFO:{Color.RESET} providers: {providers}')
    if 'TensorrtExecutionProvider' not in providers:
        # Without the TensorRT provider the run below cannot produce the
        # cache this script exists for, so make that visible.
        print(
            f'{Color.YELLOW}WARNING:{Color.RESET} '
            f'TensorrtExecutionProvider is not active; '
            f'the TensorRT engine cache may not be generated.'
        )

    input_metas = onnx_session.get_inputs()
    input_shapes = args.input_shapes

    if input_shapes is None:
        # No shapes given: fall back to the shapes declared by the model.
        input_shapes = [input_meta.shape for input_meta in input_metas]
    elif len(input_shapes) < len(input_metas):
        # zip() below would silently drop the inputs that have no shape,
        # and the inference would then fail on the missing feeds.
        parser.error(
            f'-s/--input_shapes was given {len(input_shapes)} time(s) but the '
            f'model has {len(input_metas)} input(s): '
            f'{[input_meta.name for input_meta in input_metas]}'
        )
    elif len(input_shapes) > len(input_metas):
        print(
            f'{Color.YELLOW}WARNING:{Color.RESET} '
            f'{len(input_shapes)} input shapes were given but the model has only '
            f'{len(input_metas)} input(s); the extra shapes are ignored.'
        )

    # All-ones dummy tensors. Any dimension that is not a non-negative
    # decimal number (symbolic name, None, negative value) becomes 1.
    inputs = {
        input_meta.name: np.ones(
            [int(dim) if str(dim).isdecimal() else 1 for dim in input_shape],
            dtype=ONNX_DTYPES_TO_NUMPY_TYPES[input_meta.type],
        ) for input_meta, input_shape in zip(input_metas, input_shapes)
    }
    output_names = [
        output.name for output in onnx_session.get_outputs()
    ]

    print(f'{Color.GREEN}INFO:{Color.RESET} input_names: {list(inputs)}')
    print(f'{Color.GREEN}INFO:{Color.RESET} input_shapes: {[val.shape for val in inputs.values()]}')
    print(f'{Color.GREEN}INFO:{Color.RESET} output_names: {output_names}')

    # The outputs themselves are not needed; only the run matters.
    _ = onnx_session.run(
        output_names=output_names,
        input_feed=inputs,
    )

# Run the cache generation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()