# Code snippet that only generates the ONNX Runtime TensorRT Float16 (FP16) engine cache.
#! /usr/bin/env python
import onnxruntime
import numpy as np
from argparse import ArgumentParser
class Color:
    """ANSI escape sequences for styling terminal output.

    Every attribute is a plain ``str``. Prefix text with one or more of them
    and finish with ``Color.RESET``, e.g.
    ``f'{Color.GREEN}INFO:{Color.RESET} message'``.
    """

    # Foreground colors.
    BLACK = '\033[30m'
    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
    MAGENTA = '\033[35m'
    CYAN = '\033[36m'
    WHITE = '\033[37m'
    COLOR_DEFAULT = '\033[39m'

    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    INVISIBLE = '\033[08m'
    REVERSE = '\033[07m'
    # Original (misspelled) name, kept as an alias so existing references
    # to ``Color.REVERCE`` keep working.
    REVERCE = REVERSE

    # Background colors.
    BG_BLACK = '\033[40m'
    BG_RED = '\033[41m'
    BG_GREEN = '\033[42m'
    BG_YELLOW = '\033[43m'
    BG_BLUE = '\033[44m'
    BG_MAGENTA = '\033[45m'
    BG_CYAN = '\033[46m'
    BG_WHITE = '\033[47m'
    BG_DEFAULT = '\033[49m'

    # Clears every color and attribute set so far.
    RESET = '\033[0m'
# ONNX Runtime tensor type strings mapped to the NumPy dtype used when
# fabricating dummy inputs.
_ONNX_DTYPES_TO_NUMPY_TYPES = {
    'tensor(float16)': np.float16,
    'tensor(float)': np.float32,
    'tensor(double)': np.float64,
    'tensor(int8)': np.int8,
    'tensor(int16)': np.int16,
    'tensor(int32)': np.int32,
    'tensor(int64)': np.int64,
    'tensor(uint8)': np.uint8,
    'tensor(uint16)': np.uint16,
    'tensor(uint32)': np.uint32,
    'tensor(uint64)': np.uint64,
    'tensor(bool)': np.bool_,
}


def _resolve_dim(dim):
    """Return *dim* unchanged when it prints as a decimal number, else 1.

    Anything that is not a plain non-negative number (a symbolic name,
    ``None``, a negative value such as -1) is replaced by 1.
    """
    return dim if str(dim).isdecimal() else 1


def _build_dummy_inputs(session_inputs, input_shapes=None):
    """Build an all-ones input feed for the given session inputs.

    Args:
        session_inputs: Objects exposing ``name``, ``shape`` and ``type``
            (as returned by ``InferenceSession.get_inputs()``).
        input_shapes: Optional list of shapes, one per model input and in the
            same order. When ``None`` the shapes declared by the model are
            used. Extra trailing shapes are ignored.

    Returns:
        dict mapping each input name to a ``numpy.ndarray`` filled with ones.

    Raises:
        ValueError: If fewer shapes than model inputs are supplied, or an
            input has a tensor type with no known NumPy equivalent.
    """
    session_inputs = list(session_inputs)

    if input_shapes is None:
        shapes = [node.shape for node in session_inputs]
    else:
        # zip() below would silently drop the inputs that have no shape and
        # the failure would only surface later inside run(); fail early.
        if len(input_shapes) < len(session_inputs):
            raise ValueError(
                f'{len(input_shapes)} input shape(s) were specified but the '
                f'model has {len(session_inputs)} input(s): '
                f'{[node.name for node in session_inputs]}'
            )
        shapes = input_shapes

    feed = {}
    for node, shape in zip(session_inputs, shapes):
        try:
            dtype = _ONNX_DTYPES_TO_NUMPY_TYPES[node.type]
        except KeyError:
            raise ValueError(
                f'Unsupported input type {node.type!r} for input {node.name!r}'
            ) from None
        feed[node.name] = np.ones(
            [_resolve_dim(dim) for dim in shape],
            dtype=dtype,
        )
    return feed


def main():
    """Create a TensorRT FP16 session for an ONNX model and run it once.

    Command-line arguments:
        -i / --input_onnx_file_path: Path of the ONNX model (required).
        -s / --input_shapes: Fixed shape for one input; repeat the option
            once per model input, in model input order. When omitted, the
            shapes declared in the model are used with every non-numeric
            dimension replaced by 1.
        -o / --output_cache_folder_path: Folder handed to the TensorRT
            execution provider as its engine cache path (required).

    Raises:
        ValueError: If the supplied shapes do not cover every model input or
            an input uses an unsupported tensor type.
    """
    parser = ArgumentParser()
    parser.add_argument(
        '-i',
        '--input_onnx_file_path',
        type=str,
        required=True,
        help='INPUT ONNX file path',
    )
    parser.add_argument(
        '-s',
        '--input_shapes',
        type=int,
        nargs='+',
        action='append',
        default=None,
        help='INPUT fixed shape. e.g.: 1 3 224 224',
    )
    parser.add_argument(
        '-o',
        '--output_cache_folder_path',
        type=str,
        required=True,
        help='OUTPUT cache folder path',
    )
    args = parser.parse_args()

    session_option = onnxruntime.SessionOptions()
    # Severity 3 keeps only error-level (and above) log output.
    session_option.log_severity_level = 3
    onnx_session = onnxruntime.InferenceSession(
        path_or_bytes=args.input_onnx_file_path,
        sess_options=session_option,
        providers=[
            (
                'TensorrtExecutionProvider', {
                    'trt_engine_cache_enable': True,
                    'trt_engine_cache_path': args.output_cache_folder_path,
                    'trt_fp16_enable': True,
                }
            ),
        ],
    )

    providers = onnx_session.get_providers()
    print(f'{Color.GREEN}INFO:{Color.RESET} providers: {providers}')
    # Without the TensorRT provider the run below still succeeds on another
    # provider but produces no engine cache, so make that visible.
    if 'TensorrtExecutionProvider' not in providers:
        print(
            f'{Color.YELLOW}WARNING:{Color.RESET} '
            f'TensorrtExecutionProvider is not active; '
            f'no TensorRT engine cache will be generated.'
        )

    inputs = _build_dummy_inputs(onnx_session.get_inputs(), args.input_shapes)
    output_names = [output.name for output in onnx_session.get_outputs()]

    print(f'{Color.GREEN}INFO:{Color.RESET} input_names: {list(inputs)}')
    print(f'{Color.GREEN}INFO:{Color.RESET} input_shapes: {[val.shape for val in inputs.values()]}')
    print(f'{Color.GREEN}INFO:{Color.RESET} output_names: {output_names}')

    # One inference with the dummy feed; the outputs themselves are not
    # needed, the run exists only to exercise the session with these shapes.
    _ = onnx_session.run(
        output_names=output_names,
        input_feed=inputs,
    )
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()