TensorRTを試してみる - TensorRTとは -
はじめに
JetsonNanoでYOLOv4-tinyを動かす試行を行いましたが、
これらはあくまで学習のことも考慮されているDLフレームワークDarknetに任せたSW処理な上、
特に推論の軽量化を行うような工夫なしで試行をしておりました。
そこで今回は軽量化技術を採用した上で通常時とどの程度処理性能の差が生まれるかどうかを確認していくためにNVIDIA製DNNコンパイラであるTensorRTを使いたいと考えたため、
調査及びサンプルを試すところまで確認してみたいと思います。
参考
TensorRTとは
具体的な調査まとめについては別途実施するため、ここではNVIDIA公式Blogにまとめられている簡単な概要のみまとめていきます。
TensorRTはModelを取り込んで最適化をかけるOptimizer, Optimizer後のplanを元にDeployを行うRuntimeの2つに分かれます。
Optimizer
DLフレームワークから出力されたNetwork情報を受け取ってInferenceのための最適化をかけるフェーズです。現VerであるTensorRT 7.2.0ではTensorflow, Caffe, ONNXに対応。
対応しているLayerは限られているが、対応していないLayerを扱いたい場合にはCustom LayerとしてPluginを開発することで対処することも可能(例えばYOLOv4で扱われるMishを使いたい場合にはPluginが必要になる)。
他にもQAT(Quantized Aware Training)されたModelのimportも可能。こちらはDevelopment GuideにTensorFlowによるQAT実施後、Quantized ONNX Modelに変換する手法が紹介されている。
最適化の主な手法は次の4つが挙げられている。
(Multi-Stream Executionはあまり紹介されていなかったため割愛)
-
Layer & Tensor Fusion
上記はInceptionモジュールを取り上げていますが、上記にてConv+bias+Reluを結合していることが分かります。
この結合によって推論実行のレイヤー数やカーネル起動回数が少なくなることによりレイテンシが軽減されます。
結合されるLayerの種類 -
FP16 and INT8 Precision Calibration
モデルデータを量子化(PTQ)することによってメモリ削減および演算量の削減を行うことができます。
なお、Jetson NanoはMaxwellアーキテクチャのため、FP16までしか対応しておりません。
INT8を試す場合にはVoltaアーキテクチャ以降のGPUが必要です。Jetson Xavier NXがお手頃価格です。 -
Kernel Auto-Tuning
Convolution演算で使われるKernelの処理を最適化することができる、と謳われています。 -
Dynamic Tensor Memory
こちらは演算中のメモリ使用期間を指定することによる、メモリ再利用性やメモリ割り当てオーバーヘッドの回避を行うといった機能です。
Runtime
OptimizerによってシリアライズされたPlanと呼ばれるファイルをデシリアライズして実際に動かすためのエンジンです。
RUntime部分は別になっているので、Optimizerはあらかじめ行っておいて、RUntime環境にデプロイする、といった活用が可能です。
参考
NVIDIA Development Guide
TensorRT 3: Faster TensorFlow Inference and Volta Support0
試行
NVIDIA公式で提供されている(ResNet50 ONNX modelを使ったサンプル)[https://developer.nvidia.com/blog/speed-up-inference-tensorrt/]があるので試行してみます。
チュートリアル通りに動かすとコンパイルまでは上手くいきますが、下記でエラーが生じます。
./simpleOnnx_1 resnet50v2/resnet50v2.onnx resnet50v2/test_data_set_0/input_0.pb
...
: ModelImporter.cpp:179: resnetv24_dense0_fwd [Gemm] outputs: [resnetv24_dense0_fwd -> (1, 1000)],
: ModelImporter.cpp:507: Marking resnetv24_dense0_fwd_1 as output: resnetv24_dense0_fwd
ERROR: data: kMIN dimensions in profile 0 are [1,3,256,256] but input has static dimensions [1,3,224,224].
ソースコードを確認すると下記の記述があったため、shapeを修正しました。
おそらくですが、本来は224→256, 1→32にしていることからメモリ最適化をかける処理なのだと思うのですが、
入力データがstaticな次元のものなので非対応だったのかなと想定しています。
// 変更前
auto profile = builder->createOptimizationProfile();
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMIN, Dims4{1, 3, 256 , 256});
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kOPT, Dims4{1, 3, 256 , 256});
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMAX, Dims4{32, 3, 256 , 256});
config->addOptimizationProfile(profile);
// 変更後
auto profile = builder->createOptimizationProfile();
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMIN, Dims4{1, 3, 224 , 224});
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kOPT, Dims4{1, 3, 224 , 224});
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMAX, Dims4{1, 3, 224 , 224});
config->addOptimizationProfile(profile);
上記変更後に動かした結果をteeでロギングして実装と併せて挙動を確認。
- ModelBuild
CUDAで動かすEngineを生成している部分と思われる。
ONNX Modelをパースし、OptimizeをしつつPlan Fileを生成する処理ですね。
nvinfer1::ICudaEngine* createCudaEngine(string const& onnxModelPath, int batchSize)
{
const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
unique_ptr<nvinfer1::IBuilder, Destroy<nvinfer1::IBuilder>> builder{nvinfer1::createInferBuilder(gLogger)};
unique_ptr<nvinfer1::INetworkDefinition, Destroy<nvinfer1::INetworkDefinition>> network{builder->createNetworkV2(explicitBatch)};
unique_ptr<nvonnxparser::IParser, Destroy<nvonnxparser::IParser>> parser{nvonnxparser::createParser(*network, gLogger)};
unique_ptr<nvinfer1::IBuilderConfig,Destroy<nvinfer1::IBuilderConfig>> config{builder->createBuilderConfig()};
if (!parser->parseFromFile(onnxModelPath.c_str(), static_cast<int>(ILogger::Severity::kINFO)))
{
cout << "ERROR: could not parse input engine." << endl;
return nullptr;
}
builder->setMaxBatchSize(batchSize);
config->setMaxWorkspaceSize((1 << 30));
auto profile = builder->createOptimizationProfile();
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMIN, Dims4{1, 3, 224 , 224});
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kOPT, Dims4{1, 3, 224 , 224});
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMAX, Dims4{1, 3, 224 , 224});
config->addOptimizationProfile(profile);
return builder->buildEngineWithConfig(*network, *config);
}
Import処理の内訳
# モデル初期化
: ModelImporter.cpp:202: Adding network input: data with dtype: float32, dimensions: (1, 3, 224, 224)
: ImporterContext.hpp:116: Registering tensor: data for ONNX tensor: data
: ModelImporter.cpp:90: Importing initializer: resnetv24_batchnorm0_gamma
...
# 各Nodeをparseし, input解析&Register
: ModelImporter.cpp:103: Parsing node: resnetv24_batchnorm0_fwd [BatchNormalization]
: ModelImporter.cpp:119: Searching for input: data
: ModelImporter.cpp:119: Searching for input: resnetv24_batchnorm0_gamma
: ModelImporter.cpp:119: Searching for input: resnetv24_batchnorm0_beta
: ModelImporter.cpp:119: Searching for input: resnetv24_batchnorm0_running_mean
: ModelImporter.cpp:119: Searching for input: resnetv24_batchnorm0_running_var
: ModelImporter.cpp:125: resnetv24_batchnorm0_fwd [BatchNormalization] inputs: [data -> (1, 3, 224, 224)], [resnetv24_batchnorm0_gamma -> (3)], [resnetv24_batchnorm0_beta -> (3)], [resnetv24_batchnorm0_running_mean -> (3)], [resnetv24_batchnorm0_running_var -> (3)],
: ImporterContext.hpp:141: Registering layer: resnetv24_batchnorm0_fwd for ONNX node: resnetv24_batchnorm0_fwd
: ImporterContext.hpp:116: Registering tensor: resnetv24_batchnorm0_fwd for ONNX tensor: resnetv24_batchnorm0_fwd
...
# Layer fusion
: After Myelin optimization: 177 layers
: Fusing convolution weights from resnetv24_conv0_fwd with scale resnetv24_batchnorm1_fwd
: Fusing convolution weights from resnetv24_stage1_conv0_fwd with scale resnetv24_stage1_batchnorm1_fwd
: Fusing convolution weights from resnetv24_stage1_conv1_fwd with scale resnetv24_stage1_batchnorm2_fwd
...
# Scale fusion
: After scale fusion: 144 layers
: Fusing resnetv24_conv0_fwd with resnetv24_relu0_fwd
: Fusing resnetv24_stage1_batchnorm0_fwd with resnetv24_stage1_activation0
: Fusing resnetv24_stage1_conv3_fwd with resnetv24_stage1__plus0
...
# dynamic shape処理
: *************** Autotuning format combination: Float(1,224,50176,150528) -> Float(1,112,12544,802816) ***************
...
# いくつか処理候補を作って速度を検証している模様。ここは興味深いですね
: --------------- Timing Runner: resnetv24_conv0_fwd + resnetv24_relu0_fwd (CaskConvolution)
: resnetv24_conv0_fwd + resnetv24_relu0_fwd (scudnn) Set Tactic Name: maxwell_scudnn_128x32_relu_medium_nn_v1
: Tactic: 1062367460111450758 time 22.0161
: resnetv24_conv0_fwd + resnetv24_relu0_fwd (scudnn) Set Tactic Name: maxwell_scudnn_128x64_relu_large_nn_v1
: Tactic: 4337000649858996379 time 15.3781
: resnetv24_conv0_fwd + resnetv24_relu0_fwd (scudnn) Set Tactic Name: maxwell_scudnn_128x128_relu_medium_nn_v1
: Tactic: 4501471010995462441 time 6.96667
: resnetv24_conv0_fwd + resnetv24_relu0_fwd (scudnn) Set Tactic Name: maxwell_scudnn_128x64_relu_medium_nn_v1
: Tactic: 6645123197870846056 time 3.50708
: resnetv24_conv0_fwd + resnetv24_relu0_fwd (scudnn) Set Tactic Name: maxwell_scudnn_128x128_relu_large_nn_v1
: Tactic: -9137461792520977713 time 7.2326
: resnetv24_conv0_fwd + resnetv24_relu0_fwd (scudnn) Set Tactic Name: maxwell_scudnn_128x32_relu_large_nn_v1
: Tactic: -6092040395344634144 time 6.24083
: Fastest Tactic: 6645123197870846056 Time: 3.50708
...
# Optimize結果の出力
: After reformat layers: 78 layers
: Block size 1073741824
: Block size 3211264
: Block size 3211264
: Block size 3211264
: Block size 401408
: Total Activation Memory: 1083777024
- Runtime用初期化
CUDAの初期化を実行。
初期化範囲はengine側から取得し、Input TensorはONNX Modelから取得します。
for (int i = 0; i < engine->getNbBindings(); ++i)
{
Dims dims{engine->getBindingDimensions(i)};
size_t size = accumulate(dims.d+1, dims.d + dims.nbDims, batchSize, multiplies<size_t>());
// Create CUDA buffer for Tensor.
cudaMalloc(&bindings[i], batchSize * size * sizeof(float));
// Resize CPU buffers to fit Tensor.
if (engine->bindingIsInput(i)){
inputTensor.resize(size);
}
else
outputTensor.resize(size);
}
// Read input tensor from ONNX file.
if (readTensor(inputFiles, inputTensor) != inputTensor.size())
{
cout << "Couldn't read input Tensor" << endl;
return 1;
}
- Contextの取得
Runtime実行に必要な設定の取得。Engineから取得。
// Create Execution Context.
context.reset(engine->createExecutionContext());
Dims dims_i{engine->getBindingDimensions(0)};
Dims4 inputDims{batchSize, dims_i.d[1], dims_i.d[2], dims_i.d[3]};
context->setBindingDimensions(0, inputDims);
- Runtime実行
下記で実行している模様。
フローとしては、cudaMemcpyAsyncで非同期にホストからCUDAに対して入力データをコピーし、enqueueV2で推論要求をキューイング、
cudaMemcpyAsyncで非同期にCUDAからホストに対して出力データをコピー, という流れの模様。
void launchInference(IExecutionContext* context, cudaStream_t stream, vector<float> const& inputTensor, vector<float>& outputTensor, void** bindings, int batchSize)
{
int inputId = getBindingInputIndex(context);
cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, stream);
context->enqueueV2(bindings, stream, nullptr);
cudaMemcpyAsync(outputTensor.data(), bindings[1 - inputId], outputTensor.size() * sizeof(float), cudaMemcpyDeviceToHost, stream);
}
上記処理の後、cudaStreamSynchronizeによってGPU演算がすべて終わるまで待機します。
その後のコードは結果確認およびcudaFreeによるメモリ開放処理を行っております。
-
Profiling
CUDAのプロファイリング手法については下記にまとまっていましたのでリンクをつけておきます。
(CUDAコードの実行時間を測定する方法のまとめ)[https://qiita.com/syo0901/items/7ea3b8dfc01fd5cc2cf4]
公式ではCudaEVENTによるプロファイル方法が紹介されていました。 -
FP16 and INT8 Precision Calibration
下記のFlagを追加する必要あり。
INT8を使う場合は加えてネットワーク自体のスケーリングが必要です。
// FP16
config->setFlag(BuilderFlag::kFP16);
// INT8
inline void setAllTensorScales(INetworkDefinition* network, float inScales = 2.0f, float outScales = 4.0f)
{
// Ensure that all layer inputs have a scale.
for (int i = 0; i < network->getNbLayers(); i++)
{
auto layer = network->getLayer(i);
for (int j = 0; j < layer->getNbInputs(); j++)
{
ITensor* input{layer->getInput(j)};
// Optional inputs are nullptr here and are from RNN layers.
if (input != nullptr && !input->dynamicRangeIsSet())
{
ASSERT(input->setDynamicRange(-inScales, inScales));
}
}
}
// Ensure that all layer outputs have a scale.
// Tensors that are also inputs to layers are ingored here
// since the previous loop nest assigned scales to them.
for (int i = 0; i < network->getNbLayers(); i++)
{
auto layer = network->getLayer(i);
for (int j = 0; j < layer->getNbOutputs(); j++)
{
ITensor* output{layer->getOutput(j)};
// Optional outputs are nullptr here and are from RNN layers.
if (output != nullptr && !output->dynamicRangeIsSet())
{
// Pooling must have the same input and output scales.
if (layer->getType() == LayerType::kPOOLING)
{
ASSERT(output->setDynamicRange(-inScales, inScales));
}
else
{
ASSERT(output->setDynamicRange(-outScales, outScales));
}
}
}
}
}
config->setFlag(BuilderFlag::kINT8);
setAllTensorScales(network.get(), 127.0f, 127.0f);
最後に
今回はTensorRTについてサンプルを動かしつつ簡単なフローの確認を行ってみました。
今回はパフォーマンスがどれくらい変わるかなどの検証を行っていないですしFP16による効果確認も出来てないので、
次回はどの程度の効果があるかを確認しながら、YOLOv4-tinyの最適化をTensorRTで試していきたいと思います。
Discussion