🤪

Adding custom ops to the TFLite runtime and building the .whl installer

Published on 2021/01/23
2

This is the procedure for building a TensorFlow Lite runtime that can run inference on models using MediaPipe's custom operations: MaxPoolingWithArgmax2D, MaxUnpooling2D, and Convolution2DTransposeBias.

1. Environment

  • Ubuntu 18.04 x86_64
  • Python 3.6
  • TensorFlow v2.4.0
  • MediaPipe 0.8.2
  • Docker

2. Procedure

$ cd ${HOME} && mkdir tflitecustom && cd tflitecustom
$ git clone -b v2.4.0 https://github.com/tensorflow/tensorflow.git
$ git clone -b 0.8.2 https://github.com/google/mediapipe.git

1

Go to tensorflow/lite/kernels (relative to the TensorFlow root directory) and carry out the following steps.

1. Copy the custom operation files into that directory. (In my case, I copied max_pool_argmax.cc, max_pool_argmax.h, max_unpooling.cc, max_unpooling.h, transpose_conv_bias.cc, and transpose_conv_bias.h.)

$ cd mediapipe/mediapipe/util/tflite/operations && \
  cp max_pool_argmax.cc ${HOME}/tflitecustom/tensorflow/tensorflow/lite/kernels && \
  cp max_pool_argmax.h ${HOME}/tflitecustom/tensorflow/tensorflow/lite/kernels && \
  cp max_unpooling.cc ${HOME}/tflitecustom/tensorflow/tensorflow/lite/kernels && \
  cp max_unpooling.h ${HOME}/tflitecustom/tensorflow/tensorflow/lite/kernels && \
  cp transpose_conv_bias.cc ${HOME}/tflitecustom/tensorflow/tensorflow/lite/kernels && \
  cp transpose_conv_bias.h ${HOME}/tflitecustom/tensorflow/tensorflow/lite/kernels
$ cd ${HOME}/tflitecustom

2
3

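Change the namespace of the six copied files (max_pool_argmax.h, max_unpooling.h, transpose_conv_bias.h, max_pool_argmax.cc, max_unpooling.cc, transpose_conv_bias.cc) to tflite::ops::custom, and update their include paths and header guards to the new location, as shown below.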

tensorflow/tensorflow/lite/kernels/max_pool_argmax.h
#ifndef TENSORFLOW_LITE_KERNELS_MAX_POOL_ARGMAX_H_
#define TENSORFLOW_LITE_KERNELS_MAX_POOL_ARGMAX_H_

// Declares the registration hook for MediaPipe's MaxPoolingWithArgmax2D
// custom op, moved into tflite::ops::custom so that kernels/register.cc can
// add it to the op resolver.

// NOTE(review): this declaration only needs TfLiteRegistration; the .cc file
// may rely on this include transitively, so confirm before narrowing it.
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {

// Returns the kernel registration for the "MaxPoolingWithArgmax2D" custom op.
// The returned object is a function-local static inside the kernel's .cc
// file; callers must not free it.
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D();

}  // namespace custom
}  // namespace ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_MAX_POOL_ARGMAX_H_
tensorflow/tensorflow/lite/kernels/max_unpooling.h
#ifndef TENSORFLOW_LITE_KERNELS_MAX_UNPOOLING_H_
#define TENSORFLOW_LITE_KERNELS_MAX_UNPOOLING_H_

// Declares the registration hook for MediaPipe's MaxUnpooling2D custom op,
// moved into tflite::ops::custom so that kernels/register.cc can add it to
// the op resolver.

// NOTE(review): this declaration only needs TfLiteRegistration; the .cc file
// may rely on this include transitively, so confirm before narrowing it.
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {

// Returns the kernel registration for the "MaxUnpooling2D" custom op.
// The returned object is a function-local static inside the kernel's .cc
// file; callers must not free it.
TfLiteRegistration* RegisterMaxUnpooling2D();

}  // namespace custom
}  // namespace ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_MAX_UNPOOLING_H_
tensorflow/tensorflow/lite/kernels/transpose_conv_bias.h
#ifndef TENSORFLOW_LITE_KERNELS_TRANSPOSE_CONV_BIAS_H_
#define TENSORFLOW_LITE_KERNELS_TRANSPOSE_CONV_BIAS_H_

// Declares the registration hook for MediaPipe's Convolution2DTransposeBias
// custom op, moved into tflite::ops::custom so that kernels/register.cc can
// add it to the op resolver.

// NOTE(review): neither include is needed by the declaration below beyond
// TfLiteRegistration; presumably transpose_conv_bias.cc relies on them
// transitively, so confirm before removing either.
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {

// Returns the kernel registration for the "Convolution2DTransposeBias" custom
// op. The returned object is a function-local static inside the kernel's .cc
// file; callers must not free it.
TfLiteRegistration* RegisterConvolution2DTransposeBias();

}  // namespace custom
}  // namespace ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_TRANSPOSE_CONV_BIAS_H_
tensorflow/tensorflow/lite/kernels/max_pool_argmax.cc
#include "tensorflow/lite/kernels/max_pool_argmax.h"

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/padding.h"

namespace tflite {
namespace ops {
namespace custom {
// Kernel implementation copied from MediaPipe; only the enclosing namespaces
// (and include paths) were changed for this build.
namespace MaxPoolingWithArgmax2D {

// Tensor indices. kDataInputTensor and kOutputTensor are both 0, presumably
// because inputs and outputs are indexed separately; kIndicesTensor is
// presumably the argmax output — confirm against Eval().
constexpr int kDataInputTensor = 0;
constexpr int kOutputTensor = 0;
constexpr int kIndicesTensor = 1;
// The rest of the MediaPipe implementation is elided in this excerpt. It must
// define Prepare() and Eval(), which RegisterMaxPoolingWithArgmax2D() uses.
:
:
// End of MediaPipe modification.
// End of copy.

} // namespace MaxPoolingWithArgmax2D

// Returns the kernel registration for the "MaxPoolingWithArgmax2D" custom op.
// The registration is built once, on first call, and the same object is
// returned afterwards; callers must not free it.
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D() {
  static TfLiteRegistration registration = [] {
    // Value-initialize so every field not assigned below stays zero/null,
    // exactly as with a short positional aggregate initializer.
    TfLiteRegistration r{};
    // init allocates a TfLitePaddingValues buffer; free deletes it again.
    r.init = [](TfLiteContext*, const char*, size_t) -> void* {
      return new TfLitePaddingValues();
    };
    r.free = [](TfLiteContext*, void* buffer) -> void {
      delete reinterpret_cast<TfLitePaddingValues*>(buffer);
    };
    r.prepare = MaxPoolingWithArgmax2D::Prepare;
    r.invoke = MaxPoolingWithArgmax2D::Eval;
    return r;
  }();
  return &registration;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite
tensorflow/tensorflow/lite/kernels/max_unpooling.cc
#include "tensorflow/lite/kernels/max_unpooling.h"

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/padding.h"

namespace tflite {
namespace ops {
namespace custom {
// Kernel implementation copied from MediaPipe; only the enclosing namespaces
// (and include paths) were changed for this build.
namespace MaxUnpooling2D {

// Tensor indices. kDataInputTensor and kIndicesTensor presumably index the
// node's inputs and kOutputTensor its outputs — confirm against Eval().
constexpr int kDataInputTensor = 0;
constexpr int kIndicesTensor = 1;
constexpr int kOutputTensor = 0;
// The rest of the MediaPipe implementation is elided in this excerpt. It must
// define Prepare() and Eval(), which RegisterMaxUnpooling2D() uses.
:
:
  return kTfLiteOk;
}

} // namespace MaxUnpooling2D

// Returns the kernel registration for the "MaxUnpooling2D" custom op.
// The registration is built once, on first call, and the same object is
// returned afterwards; callers must not free it.
TfLiteRegistration* RegisterMaxUnpooling2D() {
  static TfLiteRegistration registration = [] {
    // Value-initialize so every field not assigned below stays zero/null,
    // exactly as with a short positional aggregate initializer.
    TfLiteRegistration r{};
    // init allocates a TfLitePaddingValues buffer; free deletes it again.
    r.init = [](TfLiteContext*, const char*, size_t) -> void* {
      return new TfLitePaddingValues();
    };
    r.free = [](TfLiteContext*, void* buffer) -> void {
      delete reinterpret_cast<TfLitePaddingValues*>(buffer);
    };
    r.prepare = MaxUnpooling2D::Prepare;
    r.invoke = MaxUnpooling2D::Eval;
    return r;
  }();
  return &registration;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite
tensorflow/tensorflow/lite/kernels/transpose_conv_bias.cc
#include "tensorflow/lite/kernels/transpose_conv_bias.h"

#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/padding.h"

namespace tflite {
namespace ops {
namespace custom {
// Kernel implementation copied from MediaPipe; only the enclosing namespaces
// (and include paths) were changed for this build.
namespace Convolution2DTransposeBias {

// Tensor indices. Data (0), weights (1) and bias (2) presumably index the
// node's inputs and kOutputTensor its outputs — confirm against Eval().
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kDataInputTensor = 0;
constexpr int kOutputTensor = 0;
// The rest of the MediaPipe implementation is elided in this excerpt. It must
// define Prepare() and Eval(), which RegisterConvolution2DTransposeBias()
// uses.
:
:
  return kTfLiteOk;
}
// End of copy.

} // namespace Convolution2DTransposeBias

// Returns the kernel registration for the "Convolution2DTransposeBias" custom
// op. The registration is built once, on first call, and the same object is
// returned afterwards; callers must not free it.
TfLiteRegistration* RegisterConvolution2DTransposeBias() {
  static TfLiteRegistration registration = [] {
    // Value-initialize: init and free stay null (this kernel allocates no
    // per-node buffer), as do all other unassigned fields.
    TfLiteRegistration r{};
    r.prepare = Convolution2DTransposeBias::Prepare;
    r.invoke = Convolution2DTransposeBias::Eval;
    return r;
  }();
  return &registration;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite

2. Edit register.cc and register_ref.cc as shown below. First, add the three new declarations (the last three lines of each "To" block) inside namespace custom:

From:tensorflow/tensorflow/lite/kernels/register.cc
namespace tflite {
namespace ops {
namespace custom {

TfLiteRegistration* Register_NUMERIC_VERIFY();
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
// The MediaPipe Register*() declarations are appended after this line.

}  // namespace custom
To:tensorflow/tensorflow/lite/kernels/register.cc
namespace tflite {
namespace ops {
namespace custom {

TfLiteRegistration* Register_NUMERIC_VERIFY();
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
// MediaPipe custom ops, defined in max_pool_argmax.cc, max_unpooling.cc and
// transpose_conv_bias.cc. These must match the declarations in the
// corresponding headers (same namespace and signature), or linking fails.
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D();
TfLiteRegistration* RegisterMaxUnpooling2D();
TfLiteRegistration* RegisterConvolution2DTransposeBias();

}  // namespace custom
From:tensorflow/tensorflow/lite/kernels/register_ref.cc
namespace tflite {
namespace ops {

namespace custom {

// Note: the reference resolver declares the _REF variant of NumericVerify.
TfLiteRegistration* Register_NUMERIC_VERIFY_REF();
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
// The MediaPipe Register*() declarations are appended after this line.

}  // namespace custom
To:tensorflow/tensorflow/lite/kernels/register_ref.cc
namespace tflite {
namespace ops {

namespace custom {

// Note: the reference resolver declares the _REF variant of NumericVerify.
TfLiteRegistration* Register_NUMERIC_VERIFY_REF();
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
// MediaPipe custom ops, defined in max_pool_argmax.cc, max_unpooling.cc and
// transpose_conv_bias.cc. These must match the declarations in the
// corresponding headers (same namespace and signature), or linking fails.
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D();
TfLiteRegistration* RegisterMaxUnpooling2D();
TfLiteRegistration* RegisterConvolution2DTransposeBias();

}  // namespace custom

3. Next, add the following lines at the end of the op resolver constructor in both register.cc and register_ref.cc:

  // MediaPipe custom ops: the string is the custom-op name the resolver
  // matches; the Register*() functions live in tflite::ops::custom.
  AddCustom("MaxPoolingWithArgmax2D",
            tflite::ops::custom::RegisterMaxPoolingWithArgmax2D());
  AddCustom("MaxUnpooling2D", tflite::ops::custom::RegisterMaxUnpooling2D());
  AddCustom("Convolution2DTransposeBias",
            tflite::ops::custom::RegisterConvolution2DTransposeBias());
From:tensorflow/tensorflow/lite/kernels/register.cc
namespace builtin {

// Builtin resolver constructor in register.cc (excerpt): registers the builtin
// kernels followed by the default custom ops.
BuiltinOpResolver::BuiltinOpResolver() {
  AddBuiltin(BuiltinOperator_ABS, Register_ABS(), /* min_version = */ 1,
             /* max_version = */ 2);
  AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH());
  AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
             /* max_version = */ 2);
  AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
  // (remaining AddBuiltin calls elided in this excerpt)
   :
   :
  AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
  AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(),
             /* min_version = */ 1,
             /* max_version = */ 3);
  AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM());
  AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
  // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
  // custom ops aren't always included by default.
  AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
  AddCustom("AudioSpectrogram",
            tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
  AddCustom("TFLite_Detection_PostProcess",
            tflite::ops::custom::Register_DETECTION_POSTPROCESS());
  // The MediaPipe AddCustom(...) calls are appended after this line.
}
To:tensorflow/tensorflow/lite/kernels/register.cc
namespace builtin {

// Builtin resolver constructor in register.cc (excerpt), extended with the
// MediaPipe custom ops at the end.
BuiltinOpResolver::BuiltinOpResolver() {
  AddBuiltin(BuiltinOperator_ABS, Register_ABS(), /* min_version = */ 1,
             /* max_version = */ 2);
  AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH());
  AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
             /* max_version = */ 2);
  AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
  // (remaining AddBuiltin calls elided in this excerpt)
   :
   :
  AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
  AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(),
             /* min_version = */ 1,
             /* max_version = */ 3);
  AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM());
  AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
  // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
  // custom ops aren't always included by default.
  AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
  AddCustom("AudioSpectrogram",
            tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
  AddCustom("TFLite_Detection_PostProcess",
            tflite::ops::custom::Register_DETECTION_POSTPROCESS());
  // MediaPipe custom ops; their Register*() functions are declared in
  // tflite::ops::custom in this file.
  AddCustom("MaxPoolingWithArgmax2D", tflite::ops::custom::RegisterMaxPoolingWithArgmax2D());
  AddCustom("MaxUnpooling2D", tflite::ops::custom::RegisterMaxUnpooling2D());
  AddCustom("Convolution2DTransposeBias", tflite::ops::custom::RegisterConvolution2DTransposeBias());
}
From:tensorflow/tensorflow/lite/kernels/register_ref.cc
namespace builtin {

BuiltinOpResolver::BuiltinOpResolver() {
  AddBuiltin(BuiltinOperator_ABS, Register_ABS(), /* min_version = */ 1,
             /* max_version = */ 2);
  AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH());
  AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
             /* max_version = */ 2);
  AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
   :
   :
  AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
  AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(),
             /* min_version = */ 1,
             /* max_version = */ 3);
  AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM());
  AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
  // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
  // custom ops aren't always included by default.
  AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
  AddCustom("AudioSpectrogram",
            tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
  AddCustom("TFLite_Detection_PostProcess",
            tflite::ops::custom::Register_DETECTION_POSTPROCESS());
}
To:tensorflow/tensorflow/lite/kernels/register_ref.cc
namespace builtin {

BuiltinOpResolver::BuiltinOpResolver() {
  AddBuiltin(BuiltinOperator_ABS, Register_ABS(), /* min_version = */ 1,
             /* max_version = */ 2);
  AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH());
  AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
             /* max_version = */ 2);
  AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
   :
   :
  AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
  AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(),
             /* min_version = */ 1,
             /* max_version = */ 3);
  AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM());
  AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
  // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
  // custom ops aren't always included by default.
  AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
  AddCustom("AudioSpectrogram",
            tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
  AddCustom("TFLite_Detection_PostProcess",
            tflite::ops::custom::Register_DETECTION_POSTPROCESS());
  AddCustom("MaxPoolingWithArgmax2D", tflite::ops::custom::RegisterMaxPoolingWithArgmax2D());
  AddCustom("MaxUnpooling2D", tflite::ops::custom::RegisterMaxUnpooling2D());
  AddCustom("Convolution2DTransposeBias", tflite::ops::custom::RegisterConvolution2DTransposeBias());
}

4. Add the new source and header files to the BUILD file in the same directory:

From:tensorflow/tensorflow/lite/kernels/BUILD
cc_library(
    name = "builtin_op_kernels",
    # The MediaPipe .cc files are appended to srcs.
    srcs = BUILTIN_KERNEL_SRCS,
    hdrs = [
        "dequantize.h",
        # The MediaPipe .h files are appended here.
    ],
    compatible_with = get_compatible_with_portable(),
    copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
    visibility = ["//visibility:private"],
    deps = BUILTIN_KERNEL_DEPS + [
        "@ruy//ruy/profiler:instrumentation",
        "//tensorflow/lite/kernels/internal:cppmath",
        "//tensorflow/lite:string",
        "@farmhash_archive//:farmhash",
    ],
)
To:tensorflow/tensorflow/lite/kernels/BUILD
cc_library(
    name = "builtin_op_kernels",
    srcs = BUILTIN_KERNEL_SRCS + [
        # MediaPipe custom op kernels, registered in register.cc and
        # register_ref.cc.
        "max_pool_argmax.cc",
        "max_unpooling.cc",
        "transpose_conv_bias.cc",
    ],
    hdrs = [
        "dequantize.h",
        # Headers declaring the MediaPipe Register*() functions.
        "max_pool_argmax.h",
        "max_unpooling.h",
        "transpose_conv_bias.h",
    ],
    compatible_with = get_compatible_with_portable(),
    copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
    visibility = ["//visibility:private"],
    deps = BUILTIN_KERNEL_DEPS + [
        "@ruy//ruy/profiler:instrumentation",
        "//tensorflow/lite/kernels/internal:cppmath",
        "//tensorflow/lite:string",
        "@farmhash_archive//:farmhash",
    ],
)

5. Modify the CI Docker build files provided in the official TensorFlow repository.

tensorflow/tensorflow/tools/ci_build/Dockerfile.cpu
$ cd ${HOME}/tflitecustom/tensorflow
$ nano tensorflow/tools/ci_build/Dockerfile.cpu

FROM ubuntu:16.04
↓
FROM ubuntu:18.04

RUN add-apt-repository -y ppa:openjdk-r/ppa && \
    add-apt-repository -y ppa:george-edison55/cmake-3.x
↓
RUN add-apt-repository -y ppa:openjdk-r/ppa

$ nano tensorflow/tools/ci_build/install/install_deb_packages.sh

# Install dependencies from ubuntu deb repository.
apt-key adv --keyserver keyserver.ubuntu.com --recv 084ECFC5828AB726
apt-get update
↓
# Install dependencies from ubuntu deb repository.
apt-get update
apt-get install dirmngr -y
apt-key adv --keyserver keyserver.ubuntu.com --recv 084ECFC5828AB726
apt-get update

6. Build the tflite_runtime wheel installer.

$ sudo CI_DOCKER_EXTRA_PARAMS="-e CI_BUILD_PYTHON=python3.6 -e CROSSTOOL_PYTHON_INCLUDE_PATH=/usr/include/python3.6" \
  tensorflow/tools/ci_build/ci_build.sh CPU \
  tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh native

7.Use the generated installer to install tflite_runtime and check its operation.

Install
$ sudo pip3 uninstall tensorboard-plugin-wit tb-nightly tensorboard \
                      tf-estimator-nightly tensorflow-gpu \
                      tensorflow tf-nightly tensorflow_estimator \
                      tflite_runtime -y
$ sudo cp  tensorflow/lite/tools/pip_package/gen/tflite_pip/python3.6/dist/tflite_runtime-2.4.1-py3-none-any.whl ${HOME}/tflitecustom
$ cd ${HOME}/tflitecustom
$ sudo pip3 install --upgrade tflite_runtime-2.4.1-py3-none-any.whl
Test
$ python3
>>> from tflite_runtime.interpreter import Interpreter
>>> interpreter = Interpreter(model_path="segm_full_v679.tflite")
>>> interpreter.allocate_tensors()
>>> input_details = interpreter.get_input_details()
>>> output_details = interpreter.get_output_details()
>>> import pprint
>>> pprint.pprint(input_details)
[{'dtype': <class 'numpy.float32'>,
  'index': 0,
  'name': 'input_1',
  'quantization': (0.0, 0),
  'quantization_parameters': {'quantized_dimension': 0,
                              'scales': array([], dtype=float32),
                              'zero_points': array([], dtype=int32)},
  'shape': array([  1, 144, 256,   3], dtype=int32),
  'shape_signature': array([  1, 144, 256,   3], dtype=int32),
  'sparsity_parameters': {}}]
>>> pprint.pprint(output_details)
[{'dtype': <class 'numpy.float32'>,
  'index': 244,
  'name': 'segment',
  'quantization': (0.0, 0),
  'quantization_parameters': {'quantized_dimension': 0,
                              'scales': array([], dtype=float32),
                              'zero_points': array([], dtype=int32)},
  'shape': array([  1, 144, 256,   2], dtype=int32),
  'shape_signature': array([  1, 144, 256,   2], dtype=int32),
  'sparsity_parameters': {}}]
>>> exit()

Discussion

KieranKieran

I'm having trouble compiling this with cmake, but it should be similar? Maybe something has been updated?
The primary issue being there is no register.cc file at all anymore.
The transpose_conv_bias also fails to find the header even when you leave it in the same folder.

Any assistance???