Add a custom OP to the TFLite runtime to build the whl installer
This is the procedure for building a TensorFlow Lite runtime that can run inference on models that use the following MediaPipe custom operations: MaxPoolingWithArgmax2D
, MaxUnpooling2D
, Convolution2DTransposeBias
1. Environment
- Ubuntu 18.04 x86_64
- Python 3.6
- TensorFlow v2.4.0
- MediaPipe 0.8.2
- Docker
2. Procedure
$ cd ${HOME} && mkdir tflitecustom && cd tflitecustom
$ git clone -b v2.4.0 https://github.com/tensorflow/tensorflow.git
$ git clone -b 0.8.2 https://github.com/google/mediapipe.git
Navigate to the path tensorflow/lite/kernels
from tensorflow root directory and do the following steps.
1. Paste the custom operation files into the path mentioned above. (In my case I pasted max_pool_argmax.cc
, max_pool_argmax.h
, max_unpooling.cc
, max_unpooling.h
, transpose_conv_bias.cc
, transpose_conv_bias.h
)
$ cd mediapipe/mediapipe/util/tflite/operations && \
cp max_pool_argmax.cc ${HOME}/tflitecustom/tensorflow/tensorflow/lite/kernels && \
cp max_pool_argmax.h ${HOME}/tflitecustom/tensorflow/tensorflow/lite/kernels && \
cp max_unpooling.cc ${HOME}/tflitecustom/tensorflow/tensorflow/lite/kernels && \
cp max_unpooling.h ${HOME}/tflitecustom/tensorflow/tensorflow/lite/kernels && \
cp transpose_conv_bias.cc ${HOME}/tflitecustom/tensorflow/tensorflow/lite/kernels && \
cp transpose_conv_bias.h ${HOME}/tflitecustom/tensorflow/tensorflow/lite/kernels
$ cd ${HOME}/tflitecustom
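Change the namespace of the six files to tflite::ops::custom, and update each .cc file's #include of its own header to the tensorflow/lite/kernels/ path, as shown below. max_pool_argmax.h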
max_unpooling.h
transpose_conv_bias.h
max_pool_argmax.cc
max_unpooling.cc
transpose_conv_bias.cc
#ifndef TENSORFLOW_LITE_KERNELS_MAX_POOL_ARGMAX_H_
#define TENSORFLOW_LITE_KERNELS_MAX_POOL_ARGMAX_H_
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D();
} // namespace custom
} // namespace ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_MAX_POOL_ARGMAX_H_
#ifndef TENSORFLOW_LITE_KERNELS_MAX_UNPOOLING_H_
#define TENSORFLOW_LITE_KERNELS_MAX_UNPOOLING_H_
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* RegisterMaxUnpooling2D();
} // namespace custom
} // namespace ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_MAX_UNPOOLING_H_
#ifndef TENSORFLOW_LITE_KERNELS_TRANSPOSE_CONV_BIAS_H_
#define TENSORFLOW_LITE_KERNELS_TRANSPOSE_CONV_BIAS_H_
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* RegisterConvolution2DTransposeBias();
} // namespace custom
} // namespace ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_TRANSPOSE_CONV_BIAS_H_
#include "tensorflow/lite/kernels/max_pool_argmax.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/padding.h"
namespace tflite {
namespace ops {
namespace custom {
namespace MaxPoolingWithArgmax2D {
constexpr int kDataInputTensor = 0;
constexpr int kOutputTensor = 0;
constexpr int kIndicesTensor = 1;
:
:
// End of MediaPipe modification.
// End of copy.
} // namespace MaxPoolingWithArgmax2D
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D() {
static TfLiteRegistration reg = {
[](TfLiteContext*, const char*, size_t) -> void* {
return new TfLitePaddingValues();
},
[](TfLiteContext*, void* buffer) -> void {
delete reinterpret_cast<TfLitePaddingValues*>(buffer);
},
MaxPoolingWithArgmax2D::Prepare, MaxPoolingWithArgmax2D::Eval};
return &reg;
}
} // namespace custom
} // namespace ops
} // namespace tflite
#include "tensorflow/lite/kernels/max_unpooling.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/padding.h"
namespace tflite {
namespace ops {
namespace custom {
namespace MaxUnpooling2D {
constexpr int kDataInputTensor = 0;
constexpr int kIndicesTensor = 1;
constexpr int kOutputTensor = 0;
:
:
return kTfLiteOk;
}
} // namespace MaxUnpooling2D
TfLiteRegistration* RegisterMaxUnpooling2D() {
static TfLiteRegistration reg = {
[](TfLiteContext*, const char*, size_t) -> void* {
return new TfLitePaddingValues();
},
[](TfLiteContext*, void* buffer) -> void {
delete reinterpret_cast<TfLitePaddingValues*>(buffer);
},
MaxUnpooling2D::Prepare, MaxUnpooling2D::Eval};
return &reg;
}
} // namespace custom
} // namespace ops
} // namespace tflite
#include "tensorflow/lite/kernels/transpose_conv_bias.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/padding.h"
namespace tflite {
namespace ops {
namespace custom {
namespace Convolution2DTransposeBias {
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kDataInputTensor = 0;
constexpr int kOutputTensor = 0;
:
:
return kTfLiteOk;
}
// End of copy.
} // namespace Convolution2DTransposeBias
TfLiteRegistration* RegisterConvolution2DTransposeBias() {
static TfLiteRegistration reg = {nullptr, nullptr, Convolution2DTransposeBias::Prepare, Convolution2DTransposeBias::Eval};
return &reg;
}
} // namespace custom
} // namespace ops
} // namespace tflite
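2. Make changes to the following files (register.cc
and register_ref.cc
) as described below. First, add the three declarations for the new ops (RegisterMaxPoolingWithArgmax2D, RegisterMaxUnpooling2D, RegisterConvolution2DTransposeBias) inside namespace custom. The snippets below show the code before and after the change, first for register.cc and then for register_ref.cc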
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_NUMERIC_VERIFY();
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
} // namespace custom
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_NUMERIC_VERIFY();
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D();
TfLiteRegistration* RegisterMaxUnpooling2D();
TfLiteRegistration* RegisterConvolution2DTransposeBias();
} // namespace custom
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_NUMERIC_VERIFY_REF();
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
} // namespace custom
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_NUMERIC_VERIFY_REF();
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D();
TfLiteRegistration* RegisterMaxUnpooling2D();
TfLiteRegistration* RegisterConvolution2DTransposeBias();
} // namespace custom
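3. Next, add the following three calls at the end of the BuiltinOpResolver
constructor in both files (the snippets below show the constructor before and after the change)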
AddCustom("MaxPoolingWithArgmax2D", tflite::ops::custom::RegisterMaxPoolingWithArgmax2D());
AddCustom("MaxUnpooling2D", tflite::ops::custom::RegisterMaxUnpooling2D());
AddCustom("Convolution2DTransposeBias", tflite::ops::custom::RegisterConvolution2DTransposeBias());
namespace builtin {
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_ABS, Register_ABS(), /* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH());
AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
:
:
AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(),
/* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
AddCustom("AudioSpectrogram",
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
AddCustom("TFLite_Detection_PostProcess",
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
}
namespace builtin {
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_ABS, Register_ABS(), /* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH());
AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
:
:
AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(),
/* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
AddCustom("AudioSpectrogram",
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
AddCustom("TFLite_Detection_PostProcess",
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
AddCustom("MaxPoolingWithArgmax2D", tflite::ops::custom::RegisterMaxPoolingWithArgmax2D());
AddCustom("MaxUnpooling2D", tflite::ops::custom::RegisterMaxUnpooling2D());
AddCustom("Convolution2DTransposeBias", tflite::ops::custom::RegisterConvolution2DTransposeBias());
}
namespace builtin {
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_ABS, Register_ABS(), /* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH());
AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
:
:
AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(),
/* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
AddCustom("AudioSpectrogram",
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
AddCustom("TFLite_Detection_PostProcess",
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
}
namespace builtin {
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_ABS, Register_ABS(), /* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH());
AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
:
:
AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(),
/* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
AddCustom("AudioSpectrogram",
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
AddCustom("TFLite_Detection_PostProcess",
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
AddCustom("MaxPoolingWithArgmax2D", tflite::ops::custom::RegisterMaxPoolingWithArgmax2D());
AddCustom("MaxUnpooling2D", tflite::ops::custom::RegisterMaxUnpooling2D());
AddCustom("Convolution2DTransposeBias", tflite::ops::custom::RegisterConvolution2DTransposeBias());
}
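4. Add the new source and header files to the builtin_op_kernels target in the BUILD
file in the same directory, tensorflow/lite/kernels (the target is shown below before and after the change)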
cc_library(
name = "builtin_op_kernels",
srcs = BUILTIN_KERNEL_SRCS,
hdrs = [
"dequantize.h",
],
compatible_with = get_compatible_with_portable(),
copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
visibility = ["//visibility:private"],
deps = BUILTIN_KERNEL_DEPS + [
"@ruy//ruy/profiler:instrumentation",
"//tensorflow/lite/kernels/internal:cppmath",
"//tensorflow/lite:string",
"@farmhash_archive//:farmhash",
],
)
cc_library(
name = "builtin_op_kernels",
srcs = BUILTIN_KERNEL_SRCS + [
"max_pool_argmax.cc",
"max_unpooling.cc",
"transpose_conv_bias.cc",
],
hdrs = [
"dequantize.h",
"max_pool_argmax.h",
"max_unpooling.h",
"transpose_conv_bias.h",
],
compatible_with = get_compatible_with_portable(),
copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
visibility = ["//visibility:private"],
deps = BUILTIN_KERNEL_DEPS + [
"@ruy//ruy/profiler:instrumentation",
"//tensorflow/lite/kernels/internal:cppmath",
"//tensorflow/lite:string",
"@farmhash_archive//:farmhash",
],
)
5. Modify the build Dockerfile
and the package install script available in the official TensorFlow repository so that they work with Ubuntu 18.04.
$ cd ${HOME}/tflitecustom/tensorflow
$ nano tensorflow/tools/ci_build/Dockerfile.cpu
FROM ubuntu:16.04
↓
FROM ubuntu:18.04
RUN add-apt-repository -y ppa:openjdk-r/ppa && \
add-apt-repository -y ppa:george-edison55/cmake-3.x
↓
RUN add-apt-repository -y ppa:openjdk-r/ppa
$ nano tensorflow/tools/ci_build/install/install_deb_packages.sh
# Install dependencies from ubuntu deb repository.
apt-key adv --keyserver keyserver.ubuntu.com --recv 084ECFC5828AB726
apt-get update
↓
# Install dependencies from ubuntu deb repository.
apt-get update
apt-get install dirmngr -y
apt-key adv --keyserver keyserver.ubuntu.com --recv 084ECFC5828AB726
apt-get update
6.Build the Wheel installer for TensorFlow.
$ sudo CI_DOCKER_EXTRA_PARAMS="-e CI_BUILD_PYTHON=python3.6 -e CROSSTOOL_PYTHON_INCLUDE_PATH=/usr/include/python3.6" \
tensorflow/tools/ci_build/ci_build.sh CPU \
tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh native
7.Use the generated installer to install tflite_runtime and check its operation.
$ sudo pip3 uninstall tensorboard-plugin-wit tb-nightly tensorboard \
tf-estimator-nightly tensorflow-gpu \
tensorflow tf-nightly tensorflow_estimator \
tflite_runtime -y
$ sudo cp tensorflow/lite/tools/pip_package/gen/tflite_pip/python3.6/dist/tflite_runtime-2.4.1-py3-none-any.whl ${HOME}/tflitecustom
$ cd ${HOME}/tflitecustom
$ sudo pip3 install --upgrade tflite_runtime-2.4.1-py3-none-any.whl
$ python3
>>> from tflite_runtime.interpreter import Interpreter
>>> interpreter = Interpreter(model_path="segm_full_v679.tflite")
>>> interpreter.allocate_tensors()
>>> input_details = interpreter.get_input_details()
>>> output_details = interpreter.get_output_details()
>>> import pprint
>>> pprint.pprint(input_details)
[{'dtype': <class 'numpy.float32'>,
'index': 0,
'name': 'input_1',
'quantization': (0.0, 0),
'quantization_parameters': {'quantized_dimension': 0,
'scales': array([], dtype=float32),
'zero_points': array([], dtype=int32)},
'shape': array([ 1, 144, 256, 3], dtype=int32),
'shape_signature': array([ 1, 144, 256, 3], dtype=int32),
'sparsity_parameters': {}}]
>>> pprint.pprint(output_details)
[{'dtype': <class 'numpy.float32'>,
'index': 244,
'name': 'segment',
'quantization': (0.0, 0),
'quantization_parameters': {'quantized_dimension': 0,
'scales': array([], dtype=float32),
'zero_points': array([], dtype=int32)},
'shape': array([ 1, 144, 256, 2], dtype=int32),
'shape_signature': array([ 1, 144, 256, 2], dtype=int32),
'sparsity_parameters': {}}]
>>> exit()
Discussion
I'm having trouble compiling this with cmake, but it should be similar? Maybe something has been updated?
The primary issue being there is no register.cc file at all anymore.
The transpose_conv_bias also fails to find the header even when you leave it in the same folder.
Any assistance???
This is an article I wrote about 2 years ago, so it is not surprising that there is an error. The TensorFlow code hierarchy changes frequently.
Therefore, you will need to make modifications along with the source code for the version you are trying to customize.
I have committed a patch file for TensorFlow v2.11.0, so please read it and customize it as you like.