🤪

Custom Operation入りのtfliteを逆コンバートしてJSON化し標準OPへ置き換えたうえでtfliteを再生成する方法

43 min read

1. はじめに

記念すべきZennでの記事初投稿です。 お試しで記事にしますので少々ラフな点はご容赦願います。 今回は、 カスタムオペレーション入りtflite -> JSON化 -> JSON編集してカスタムオペレーションを削除 -> JSONをtflite変換 (tflite to JSON to tflite) の手順を記載します。

MediaPipeで公開されているtfliteモデルはとても性能が高いにも関わらず、カスタムオペレーションが含まれていてお手軽にプラットフォーム間で共用できなかったり、公開されているモデルのトレーニング用コードやモデル構造が定義されたPythonプログラムが公開されていなかったりと、とても歯がゆい思いをさせられます。 Apache2.0ライセンスで公開されているにも関わらず、です。 ということで、tfliteをハックして改造する手段を整理します。

2. 環境

  1. Ubuntu 18.04 x86_64
  2. TensorFlow 2.4.0
  3. FlatBuffers 1.12.0
  4. MediaPipe Meet Segmentation (Google Meet)

3. 手順

3-1. flatcのビルド

tflite を逆コンバートしてJSON化するツール flatc をビルドして手に入れます。

flatcのビルド
$ git clone -b v1.12.0 https://github.com/google/flatbuffers.git
$ cd flatbuffers && mkdir build && cd build
$ cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release ..
$ make -j8

flatbuffers1

生成されたファイルの確認
$ ls -l

合計 5664
-rw-rw-r--  1 x x   20733  1月  8 23:49 CMakeCache.txt
drwxrwxr-x 13 x x    4096  1月  8 23:50 CMakeFiles
-rw-r--r--  1 x x    3578  1月  8 23:49 CPackConfig.cmake
-rw-r--r--  1 x x    4025  1月  8 23:49 CPackSourceConfig.cmake
-rw-rw-r--  1 x x     521  1月  8 23:49 CTestTestfile.cmake
-rw-rw-r--  1 x x     368  1月  8 23:49 FlatbuffersConfigVersion.cmake
-rw-rw-r--  1 x x   52147  1月  8 23:49 Makefile
-rw-rw-r--  1 x x    5821  1月  8 23:49 cmake_install.cmake
-rwxrwxr-x  1 x x 2526272  1月  8 23:50 flatc
-rwxrwxr-x  1 x x   23728  1月  8 23:49 flathash
-rwxrwxr-x  1 x x  546192  1月  8 23:50 flatsamplebfbs
-rwxrwxr-x  1 x x   51872  1月  8 23:50 flatsamplebinary
-rwxrwxr-x  1 x x  546080  1月  8 23:50 flatsampletext
-rwxrwxr-x  1 x x 1074608  1月  8 23:50 flattests
-rw-rw-r--  1 x x  902110  1月  8 23:49 libflatbuffers.a
drwxrwxr-x  4 x x    4096  1月  8 23:49 samples
drwxrwxr-x 17 x x    4096  1月  8 23:49 tests

flatbuffers2

flatc のコマンドラインオプションです。 何気に高機能ですね。

flatcのコマンドラインオプション
Usage: ./flatc [OPTION]... FILE... [-- FILE...]
  --binary         -b    Generate wire format binaries for any data definitions.
  --json           -t    Generate text output for any data definitions.
  --cpp            -c    Generate C++ headers for tables/structs.
  --go             -g    Generate Go files for tables/structs.
  --java           -j    Generate Java classes for tables/structs.
  --js             -s    Generate JavaScript code for tables/structs.
  --dart           -d    Generate Dart classes for tables/structs.
  --ts             -T    Generate TypeScript code for tables/structs.
  --csharp         -n    Generate C# classes for tables/structs.
  --python         -p    Generate Python files for tables/structs.
  --lobster              Generate Lobster files for tables/structs.
  --lua            -l    Generate Lua files for tables/structs.
  --rust           -r    Generate Rust files for tables/structs.
  --php                  Generate PHP files for tables/structs.
  --kotlin               Generate Kotlin classes for tables/structs.
  --jsonschema           Generate Json schema.
  --swift                Generate Swift files for tables/structs.
  -o PATH                Prefix PATH to all generated files.
  -I PATH                Search for includes in the specified path.
  -M                     Print make rules for generated files.
  --version              Print the version number of flatc and exit.
  --strict-json          Strict JSON: field names must be / will be quoted,
                         no trailing commas in tables/vectors.
  --allow-non-utf8       Pass non-UTF-8 input through parser and emit nonstandard
                         \x escapes in JSON. (Default is to raise parse error on
                         non-UTF-8 input.)
  --natural-utf8         Output strings with UTF-8 as human-readable strings.
                         By default, UTF-8 characters are printed as \uXXXX escapes.
  --defaults-json        Output fields whose value is the default when
                         writing JSON
  --unknown-json         Allow fields in JSON that are not defined in the
                         schema. These fields will be discared when generating
                         binaries.
  --no-prefix            Don't prefix enum values with the enum type in C++.
  --scoped-enums         Use C++11 style scoped and strongly typed enums.
                         also implies --no-prefix.
  --gen-includes         (deprecated), this is the default behavior.
                         If the original behavior is required (no include
                         statements) use --no-includes.
  --no-includes          Don't generate include statements for included
                         schemas the generated file depends on (C++ / Python).
  --gen-mutable          Generate accessors that can mutate buffers in-place.
  --gen-onefile          Generate single output file for C# and Go.
  --gen-name-strings     Generate type name functions for C++ and Rust.
  --gen-object-api       Generate an additional object-based API.
  --gen-compare          Generate operator== for object-based API types.
  --gen-nullable         Add Clang _Nullable for C++ pointer. or @Nullable for Java
  --java-checkerframe    work Add @Pure for Java.
  --gen-generated        Add @Generated annotation for Java
  --gen-all              Generate not just code for the current schema files,
                         but for all files it includes as well.
                         If the language uses a single file for output (by default
                         the case for C++ and JS), all code will end up in this one
                         file.
  --cpp-include          Adds an #include in generated file.
  --cpp-ptr-type T       Set object API pointer type (default std::unique_ptr).
  --cpp-str-type T       Set object API string type (default std::string).
                         T::c_str(), T::length() and T::empty() must be supported.
                         The custom type also needs to be constructible from std::string
                         (see the --cpp-str-flex-ctor option to change this behavior).
  --cpp-str-flex-ctor    Don't construct custom string types by passing std::string
                         from Flatbuffers, but (char* + length).
  --cpp-std CPP_STD      Generate a C++ code using features of selected C++ standard.
                         Supported CPP_STD values:
                          * 'c++0x' - generate code compatible with old compilers;
                          * 'c++11' - use C++11 code generator (default);
                          * 'c++17' - use C++17 features in generated code (experimental).
  --object-prefix        Customise class prefix for C++ object-based API.
  --object-suffix        Customise class suffix for C++ object-based API.
                         Default value is "T".
  --no-js-exports        Removes Node.js style export lines in JS.
  --goog-js-export       Uses goog.exports* for closure compiler exporting in JS.
  --es6-js-export        Uses ECMAScript 6 export style lines in JS.
  --go-namespace         Generate the overrided namespace in Golang.
  --go-import            Generate the overrided import for flatbuffers in Golang
                         (default is "github.com/google/flatbuffers/go").
  --raw-binary           Allow binaries without file_indentifier to be read.
                         This may crash flatc given a mismatched schema.
  --size-prefixed        Input binaries are size prefixed buffers.
  --proto                Input is a .proto, translate to .fbs.
  --proto-namespace-suffix Add this namespace to any flatbuffers generated
    SUFFIX                 from protobufs.
  --oneof-union          Translate .proto oneofs to flatbuffer unions.
  --grpc                 Generate GRPC interfaces for the specified languages.
  --schema               Serialize schemas instead of JSON (use with -b).
  --bfbs-comments        Add doc comments to the binary schema files.
  --bfbs-builtins        Add builtin attributes to the binary schema files.
  --bfbs-gen-embed       Generate code to embed the bfbs schema to the source.
  --conform FILE         Specify a schema the following schemas should be
                         an evolution of. Gives errors if not.
  --conform-includes     Include path for the schema given with --conform PATH
  --filename-suffix      The suffix appended to the generated file names.
                         Default is '_generated'.
  --filename-ext         The extension appended to the generated file names.
                         Default is language-specific (e.g., '.h' for C++)
  --include-prefix       Prefix this path to any generated include statements.
    PATH
  --keep-prefix          Keep original prefix of schema include statement.
  --no-fb-import         Don't include flatbuffers import statement for TypeScript.
  --no-ts-reexport       Don't re-export imported dependencies for TypeScript.
  --short-names          Use short function names for JS and TypeScript.
  --reflect-types        Add minimal type reflection to code generation.
  --reflect-names        Add minimal type/name reflection.
  --root-type T          Select or override the default root_type
  --force-defaults       Emit default values in binary output from JSON
  --force-empty          When serializing from object API representation,
                         force strings and vectors to empty rather than null.
  --force-empty-vectors  When serializing from object API representation,
                         force vectors to empty rather than null.
  --flexbuffers          Used with "binary" and "json" options, it generates
                         data using schema-less FlexBuffers.
FILEs may be schemas (must end in .fbs), binary schemas (must end in .bfbs),
or JSON files (conforming to preceding schema). FILEs after the -- must be
binary flatbuffer format files.
Output files are named using the base file name of the input,
and written to the current directory or the path given by -o.
example: ./flatc -c -b schema1.fbs schema2.fbs data.json

flatc の使い方のチュートリアルは下記の公式ページに丁寧な説明があります。 ちなみに今回は tfliteJSON へ変換する手順を使用しますが、 flatcC++ GO Java C# Python JavaScript TypeScript Lua Rust PHP Kotlin Dart のソースファイルへの変換も可能です。 凄い。

生成された flatc を作業用フォルダにコピーしておきます。

ワークフォルダの作成とflatcのコピー
$ mkdir ${HOME}/work_tflite
$ cp flatc ${HOME}/work_tflite
$ cd ${HOME}/work_tflite

3-2. schema.fbs の取得

flatctflite をフォーマット変換するためにはスキーマ定義ファイル schema.fbs というファイルが必要です。 TensorFlow の公式サイトからダウンロード可能です。 取得するバージョンにより、tfliteファイル内の構文に対応していないものがあったりしますので、なるべく最新バージョンを指定してダウンロードすることをおすすめします。 今回は v2.4.0 以降で使用可能な schema.fbs を取得します。 古いバージョンの TensorFlow で生成された tfliteファイル でも変換可能ですが、構文不一致によりところどころうまく変換できない場合があります。

schema.fbsの取得
$ wget https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/lite/schema/schema.fbs

3-3. カスタムオペレーション入りのtfliteファイルを取得

今回はカスタムオペレーションが含まれていて扱いにくいtfliteファイル、 MediaPipe の Google Meet Segmentation を使用します。 Google Meet Segmentation のモデルは解像度に応じて3種類リリースされていますが、今回は一番高解像度で高性能な segm_full_v679.tflite を使用します。入力解像度は 144 x 256 です。 といってもファイルサイズが 407 KB しか無い超軽量セグメンテーションモデルです。 ライセンスは Apache2.0 です。

下図は w-okadaさん のリポジトリにコミットされている動作サンプル画像です。 とても 400KB しかないモデルだとは思えない精細さです。
GoogleMeet Sample
様々なJavaScriptベースのデモが公開されていますのでとても参考に成ります。 気になる方は一度覗いてみてはいかがでしょうか?
This repository is the zoo of image processing webworkers for javascript. You can use these workers as npm package.
https://github.com/w-okada/image-analyze-workers

Meet Segmentation の3種類のモデルは下記の表のLinkからダウンロード可能です。

No. ファイル名 解像度 size DLリンク
1 segm_full_v679.tflite 144x256 407KB link
2 segm_lite_v509.tflite 128x128 407KB link
3 segm_lite_v681.tflite 96x160 407KB link

それでは、最も解像度が高いモデルをダウンロードします。

Google Meet Segmentation のtfliteファイル取得
$ wget https://meet.google.com/_/rtcvidproc/release_1wttl/345264209/segm_full_v679.tflite

3-4. ダウンロードしたtfliteファイルの構造を確認

Netron を使用します。 先ほどダウンロードしたtfliteファイルをアップロードするだけで簡単に構造が確認できます。 https://netron.app/
netron
表示されました。 モデルが複雑に見えますが今回注目すべき部分はモデル終端の灰色のオペレーションの部分です。 灰色のオペレーション Convolution2DTransposeBias は標準の TensorFlow Lite では読み込むことができないオペレーションで、実行時にエラーになってしまいます。 これを自力で標準のオペレーションに置き換えてしまおう、ということです。
model1

3-5. tflite の特定のオペレーションから重みを抽出

後ほど消してしまうカスタムオペレーション Convolution2DTransposeBias から重み情報を抽出しておきます。 これは、標準オペレーションに置き換えるプログラムで改めて再利用するためです。 Netron を使用して表示した Dequantize をクリックし、右端に表示されるフロッピーマークを押すとNumpyの形式で重み情報が抽出・保存できます。
weights
こんな感じです。 今の所使用しませんが、最後の変換手順で作成するPythonプログラムが標準オペレーションを生成するときに自動的に読み込むようにします。
weights folder

3-6. tflite を JSON に変換

大したことをする必要はりません。 下記のコマンド一発で変換できます。 segm_full_v679.json という名前のファイルが生成されます。 -o には出力先のパスを指定しますが、今回はカレントフォルダへ出力する指定にしてみました。

tflite to JSON
$ ./flatc -t \
  --strict-json \
  --defaults-json \
  -o . \
  schema.fbs -- segm_full_v679.tflite

3-7. JSONの構造を確認

出力されたJSONは "version" "operator_codes" "subgraphs" "description" "buffers" "metadata_buffer" というセクションで構成されています。 このうちカスタマイズの際に見なければいけないセクションは "operator_codes""subgraphs" です。

3-7-1. operator_codes

tfliteのモデルに含まれているオペレーションの種類が列記されています。

operator_codesセクション
  "operator_codes": [
    {
      "deprecated_builtin_code": 3,
      "version": 1,
      "builtin_code": "CONV_2D"
    },
    {
      "deprecated_builtin_code": 117,
      "version": 1,
      "builtin_code": "HARD_SWISH"
    },
    {
      "deprecated_builtin_code": 21,
      "version": 1,
      "builtin_code": "RELU6"
    },
    {
      "deprecated_builtin_code": 4,
      "version": 1,
      "builtin_code": "DEPTHWISE_CONV_2D"
    },
    {
      "deprecated_builtin_code": 1,
      "version": 1,
      "builtin_code": "AVERAGE_POOL_2D"
    },
    {
      "deprecated_builtin_code": 9,
      "version": 1,
      "builtin_code": "FULLY_CONNECTED"
    },
    {
      "deprecated_builtin_code": 19,
      "version": 1,
      "builtin_code": "RELU"
    },

3-7-2. subgraphs

各オペレーションの定義が列記されています。 subgraphs はさらに tensors inputs outputs operators に細分化されます。

subgraphsセクション
  "subgraphs": [
    {
      "tensors": [
        {
          "shape": [
            1,
            144,
            256,
            3
          ],
          "type": "FLOAT32",
          "buffer": 0,
          "name": "input_1",
          "is_variable": false
        },
        {
          "shape": [
            16,
            3,
            3,
            3
          ],
          "type": "FLOAT16",
          "buffer": 1,
          "name": "conv2d/Kernel",
          "is_variable": false
        },
        {
          "shape": [
            16
          ],
          "type": "FLOAT16",
          "buffer": 2,
          "name": "conv2d/Bias",
          "is_variable": false
        },

3-7-2-1. subgraphs - tensors

tensors は各オペレーションの形状、型、重み(buffersの中から適用する重みの連番)、名前、変数か否か、が定義されています。 形状と名前以外は特に意識しなくても問題ありません。

        {
          "shape": [
            2,
            2,
            2,
            16
          ],
          "type": "FLOAT32",
          "buffer": 0,
          "name": "segment/Kernel_dequantize",
          "is_variable": false
        },

3-7-2-2. subgraphs - inputs

inputs はこのモデルの入り口になるオペレーションの通番が定義されています。 これは tensors セクションに記載されているオペレーションを上から昇順に数えた通番です。 例えば下記の表記になっている場合は 0 番目 のオペレーションをINPUTとして指定しています。 混乱しやすいのは、のちほど登場する operators セクションにも inputs がありますが、 operators 側に表記されている inputs は該当オペレーションの前後の接続情報であり、意味合いが異なる点に注意してください。 モデル全体のINPUT/OUTPUT ≠ 個々のオペレーションのINPUT/OUTPUT

      "inputs": [
        0
      ],

つまり、

  "subgraphs": [
    {
      "tensors": [ <-------------------ココからスタートして数えた通番
        { <----------------------------ココから
          "shape": [
            1,
            144,
            256,
            3
          ],
          "type": "FLOAT32",
          "buffer": 0,
          "name": "input_1",
          "is_variable": false
        }, <---------------------------ココまで
        {
          "shape": [
            16,
            3,
            3,
            3
          ],
          "type": "FLOAT16",
          "buffer": 1,
          "name": "conv2d/Kernel",
          "is_variable": false
        },

の範囲を表しています。 INPUT の形状が [1, 144, 256, 3] であることが分かりますね。

3-7-2-3. subgraphs - outputs

outputs はこのモデルの出口になるオペレーションの通番が定義されています。 これは inputs と同じく tensors セクションに記載されているオペレーションを上から昇順に数えた通番です。 例えば下記の表記になっている場合は 244 番目 のオペレーションをOUTPUTとして指定しています。

      "outputs": [
        244
      ],

3-7-2-4. subgraphs - operators

各オペレーションの接続情報が定義されています。 例えば下図の場合、 2 番 のオペレーションから入力し、 271 番 として出力することを表します。 271 番inputs に指定されたオペレーションがどこかに記載されているはずですので各自探してみましょう。

        {
          "opcode_index": 13,
          "inputs": [
            2
          ],
          "outputs": [
            271
          ],
          "builtin_options_type": "NONE",
          "custom_options_format": "FLEXBUFFERS"
        },

3-8. カスタムオペレーション Convolution2DTransposeBias を削除

本題です。 灰色のカスタムオペレーション Convolution2DTransposeBias をJSONから削除します。 手がかりは 3-6-2-3. subgraphs - outputs で取り上げた outputs244 番 に指定されているオペレーションです。 分かりにくい場合はNetronを使用して表示したオペレーションの詳細情報と突き合わせをしてみてください。
Netron
では、下記の範囲をJSONから削除してしまいます。 注意点は削除したときにJSONの構文規則を崩してしまわないようにすることです。 例えば下記の場合は <-- ココから削除 と記載されている直前に表示されている , (カンマ) を同時に消す必要があります。 なお、消してしまったオペレーションは後ほど標準オペレーションに置き換えて足しますので心配する必要はありません。

          "builtin_options_type": "NONE",
          "custom_options_format": "FLEXBUFFERS"
        }, <------------------------ ココのカンマは削除必要
        { <------------------------- ココから削除
          "opcode_index": 12,
          "inputs": [
            241,
            353,
            275
          ],
          "outputs": [
            244
          ],
          "builtin_options_type": "NONE",
          "custom_options": [
            1,
            0,
            0,
            0,
            2,
            0,
            0,
            0,
            2,
            0,
            0,
            0
          ],
          "custom_options_format": "FLEXBUFFERS"
        } <------------------------- ココまで削除
      ],

3-9. 出力オペレーションを変更

モデルの終端のカスタムオペレーションを削除してしまいましたので、 3-6-2-3. subgraphs - outputs に記載されていた番号とOUTPUTの位置がズレてしまいました。 このタイミングで調整しておく必要があります。 出力オペレーションが1つだったところが3つに増えましたね。 これは先ほど削除してしまったカスタムオペレーション Convolution2DTransposeBiasinputs に指定されていたオペレーションの通番を指定したためです。 終端のオペレーションを削除しましたので、終端から数えて2つ目のオペレーションが最終OUTPUTに代わることはなんとなくイメージができますね。

変更前
      "outputs": [
        244
      ],

変更後
      "outputs": [
        241,
        353,
        275
      ],

3-10. JSON を tflite に変換

では想定通り終端のカスタムオペレーションが削除されたtfliteモデルが生成されるかどうかを確認してみましょう。 3-5. tflite を JSON に変換 の逆操作のコマンドを発行してJSONをtfliteに変換してみます。 これも大したコマンドを発行しなくても大丈夫です。 JSONのファイル名を変更する理由は、JSONのファイル名がtfliteのファイル名にそのまま適用されるのですが、そのままだとダウンロードしてきた元のtfliteファイルとファイル名が重複し上書き更新されてしまうからです。 今回はファイル名末尾に _opt を付加してみました。 -o は出力先のパス、 -b はスキーマファイルのパス、最後のパラメータが tflite へ変換するJSONファイルのパスです。

JSONをtfliteへ変換
$ mv segm_full_v679.json segm_full_v679_opt.json
$ ./flatc -o . -b schema.fbs segm_full_v679_opt.json

私の作業環境のハードコピーですので関係のないファイルが大量に見えてしまっていますが、optなしのtfliteとは別にopt付きのtfliteファイルが生成されているのが確認できますね。
tflite_opt

3-11. カスタムオペレーションを削除したtfliteファイルの構造を確認

再び Netron を使用してモデル終端が想定どおり編集できているかを確認します。 正しく編集できていてtfliteへの再変換も成功しているようですね。
tfliteopt

3-12. カスタムオペレーションを標準オペレーションに置き換えて終端に追加し最適化済みの saved_model と .tflite を生成

ここからはとてもトリッキーな実装でモデルの末尾に標準オペレーションを付加します。 と言っても、私が作成したPythonスクリプトを実行するだけです。 Google Meet 以外のモデルへ適用する場合は置き換えるカスタムオペレーションの種類に応じてロジックを見直す必要があります。 とりあえず今回のモデルの末尾の Convolution2DTransposeBias を標準オペレーションの TransposeConvAdd (BiasAdd) に置き換えて付加するロジックをご紹介します。 大まかなロジックの流れは下記の通りです。 説明に比してロジックが異様に長く感じるかもしれませんが、やっていることはシンプルですので一度辛抱強く眺めてみてください。 オマケの実装として、EdgeTPUモデルへ変換するために必要な最適化モードも実装されています。 optimizing_for_edgetpu_flg = FalseTrue に変更してから実行すると、EdgeTPUモデルへ変換可能な saved_model形式 で構造を見直して変換を掛けます。 と言っても大したことはしていなくて、今回のモデルの場合は Hard-Swish を独自の Hard-Swishもどき に置き換えているということと、 ResizeBilinear を独自の ResizeBilinearもどき に置き換えているというだけです。 この対応をしないと EdgeTPU モデルへのコンパイルは失敗します。

  1. JSONからモデルの構造情報を読み取る
  2. .tfliteからモデルの構造情報を読み取る
  3. 2で読み込んだオペレーションのリストをループ処理しながら順次結合し、TensorFlowのモデルとして再構築する
  4. ループが終わったタイミングで TransposeConvAdd (BiasAdd) を末尾に追加する
tflite to saved_model to tflite
### tf-nightly==2.5.0-dev20210104

### https://google.github.io/flatbuffers/flatbuffers_guide_tutorial.html

#!/usr/bin/env python
# coding: utf-8

import os
import numpy as np
import json
import tensorflow.compat.v1 as tf
import tensorflow as tfv2
import shutil
from pathlib import Path
import pprint

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
schema = "schema.fbs"
binary = "./flatc"
model_path = "segm_full_v679_opt.tflite"
output_pb_path = "segm_full_v679_opt.pb"
output_savedmodel_path = "saved_model"
model_json_path = "segm_full_v679_opt.json"
output_node_names = ['segment']
height = 144
width  = 256

#################################################################
# Change to True when converting to EdgeTPU model.
optimizing_for_edgetpu_flg = False
#################################################################

def gen_model_json():
    if not os.path.exists(model_json_path):
        cmd = (binary + " -t --strict-json --defaults-json -o . {schema} -- {input}".format(input=model_path, schema=schema))
        print("output json command =", cmd)
        os.system(cmd)


def parse_json():
    j = json.load(open(model_json_path))
    op_types = [v['builtin_code'] for v in j['operator_codes']]
    print('op types:', op_types)
    ops = j['subgraphs'][0]['operators']
    print('num of ops:', len(ops))
    return ops, op_types

def optimizing_hardswish_for_edgetp(input_op, name=None):
    ret_op = None
    if not optimizing_for_edgetpu_flg:
        ret_op = input_op * tf.nn.relu6(input_op + 3) * 0.16666667
    else:
        ret_op = input_op * tf.nn.relu6(input_op + 3) * 0.16666666
    return ret_op

def make_graph(ops, op_types, interpreter):
    tensors = {}
    input_details = interpreter.get_input_details()

    print(input_details)
    for input_detail in input_details:
        tensors[input_detail['index']] = tf.placeholder(
            dtype=input_detail['dtype'],
            shape=input_detail['shape'],
            name=input_detail['name'])

    for op in ops:
        print('@@@@@@@@@@@@@@ op:', op)
        op_type = op_types[op['opcode_index']]
        if op_type == 'CONV_2D':
            input_tensor = tensors[op['inputs'][0]]
            weights = tensors[op['inputs'][1]].transpose(1,2,3,0)
            bias = tensors[op['inputs'][2]]
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            options = op['builtin_options']
            output_tensor = tf.nn.conv2d(
                input_tensor,
                weights,
                strides=[1, options['stride_h'], options['stride_w'], 1],
                padding=options['padding'],
                dilations=[
                    1, options['dilation_h_factor'],
                    options['dilation_w_factor'], 1
                ],
                name=output_detail['name'] + '/conv2d')
            output_tensor = tf.add(
                output_tensor, bias, name=output_detail['name'])

            if output_detail['name'].split('/')[-1]=='Relu6':
                output_tensor = tf.nn.relu6(output_tensor)

            tensors[output_detail['index']] = output_tensor
        elif op_type == 'DEPTHWISE_CONV_2D':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor = tensors[op['inputs'][0]]
            weights = tensors[op['inputs'][1]].transpose(1,2,3,0)
            bias = tensors[op['inputs'][2]]
            options = op['builtin_options']
            output_tensor = tf.nn.depthwise_conv2d(
                input_tensor,
                weights,
                strides=[1, options['stride_h'], options['stride_w'], 1],
                padding=options['padding'],
                # dilations=[1, options['dilation_h_factor'], options['dilation_w_factor'], 1],
                name=output_detail['name'] + '/depthwise_conv2d')
            output_tensor = tf.add(output_tensor, bias, name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'MAX_POOL_2D':
            input_tensor = tensors[op['inputs'][0]]
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            options = op['builtin_options']
            output_tensor = tf.nn.max_pool(
                input_tensor,
                ksize=[
                    1, options['filter_height'], options['filter_width'], 1
                ],
                strides=[1, options['stride_h'], options['stride_w'], 1],
                padding=options['padding'],
                name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'PAD':
            input_tensor = tensors[op['inputs'][0]]
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            paddings_detail = interpreter._get_tensor_details(op['inputs'][1])
            paddings_array = interpreter.get_tensor(paddings_detail['index'])
            paddings = tf.Variable(
                paddings_array, name=paddings_detail['name'])
            output_tensor = tf.pad(
                input_tensor, paddings, name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'RELU':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor = tensors[op['inputs'][0]]
            output_tensor = tf.nn.relu(input_tensor, name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'PRELU':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor = tensors[op['inputs'][0]]
            alpha_detail = interpreter._get_tensor_details(op['inputs'][1])
            alpha_array = interpreter.get_tensor(alpha_detail['index'])
            with tf.variable_scope(name_or_scope=output_detail['name']):
                alphas = tf.Variable(alpha_array, name=alpha_detail['name'])
                output_tensor = tf.maximum(alphas * input_tensor, input_tensor)
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'RELU6':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor = tensors[op['inputs'][0]]
            output_tensor = tf.nn.relu6(input_tensor, name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor 
        elif op_type == 'RESHAPE':
            input_tensor = tensors[op['inputs'][0]]
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            options = op['builtin_options']
            output_tensor = tf.reshape(input_tensor, options['new_shape'], name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'ADD':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor_0 = tensors[op['inputs'][0]]
            try:
                input_tensor_1 = tensors[op['inputs'][1]]
            except:
                param = interpreter._get_tensor_details(op['inputs'][1])
                input_tensor_1 = interpreter.get_tensor(param['index'])
            output_tensor = tf.add(input_tensor_0, input_tensor_1, name=output_detail['name'])

            if output_detail['name'].split('/')[-1]=='Relu6':
                output_tensor = tf.nn.relu6(output_tensor)

            tensors[output_detail['index']] = output_tensor
        elif op_type == 'CONCATENATION':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor_0 = tensors[op['inputs'][0]]
            input_tensor_1 = tensors[op['inputs'][1]]
            try:
                input_tensor_2 = tensors[op['inputs'][2]]
                options = op['builtin_options']
                output_tensor = tf.concat([input_tensor_0, input_tensor_1, input_tensor_2],
                                        options['axis'],
                                        name=output_detail['name'])
            except:
                options = op['builtin_options']
                output_tensor = tf.concat([input_tensor_0, input_tensor_1],
                                        options['axis'],
                                        name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'LOGISTIC':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor = tensors[op['inputs'][0]]
            output_tensor = tf.math.sigmoid(input_tensor, name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'TRANSPOSE_CONV':
            input_tensor = tensors[op['inputs'][2]]
            weights_detail = interpreter._get_tensor_details(op['inputs'][1])
            output_shape_detail = interpreter._get_tensor_details(op['inputs'][0])
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            weights_array = interpreter.get_tensor(weights_detail['index'])
            weights_array = np.transpose(weights_array, (1, 2, 0, 3))
            output_shape_array = interpreter.get_tensor(output_shape_detail['index'])
            weights = tf.Variable(weights_array, name=weights_detail['name'])
            shape = tf.Variable(output_shape_array, name=output_shape_detail['name'])
            options = op['builtin_options']
            output_tensor = tf.nn.conv2d_transpose(input_tensor,
                                                   weights,
                                                   shape,
                                                   [1, options['stride_h'], options['stride_w'], 1],
                                                   padding=options['padding'],
                                                   name=output_detail['name'] + '/conv2d_transpose')
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'MUL':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor_0 = tensors[op['inputs'][0]]
            input_tensor_1 = None
            try:
                input_tensor_1 = tensors[op['inputs'][1]]
            except:
                param = interpreter._get_tensor_details(op['inputs'][1])
                input_tensor_1 = interpreter.get_tensor(param['index'])
            output_tensor = tf.multiply(input_tensor_0, input_tensor_1, name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'HARD_SWISH':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor = tensors[op['inputs'][0]]
            output_tensor = optimizing_hardswish_for_edgetp(input_tensor, name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'AVERAGE_POOL_2D':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor = tensors[op['inputs'][0]]
            options = op['builtin_options']
            pool_size = [options['filter_height'], options['filter_width']]
            strides = [options['stride_h'], options['stride_w']]
            padding = options['padding']
            output_tensor = tf.keras.layers.AveragePooling2D(pool_size=pool_size,
                                                             strides=strides,
                                                             padding=padding,
                                                             name=output_detail['name'])(input_tensor)
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'FULLY_CONNECTED':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor = tensors[op['inputs'][0]]
            weights = tensors[op['inputs'][1]].transpose(1,0)
            bias = tensors[op['inputs'][2]]
            output_shape_detail = interpreter._get_tensor_details(op['inputs'][0])
            output_shape_array = interpreter.get_tensor(output_shape_detail['index'])

            output_tensor = tf.keras.layers.Dense(units=output_shape_array.shape[3],
                                                  use_bias=True,
                                                  kernel_initializer=tf.keras.initializers.Constant(weights),
                                                  bias_initializer=tf.keras.initializers.Constant(bias))(input_tensor)
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'RESIZE_BILINEAR':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor = tensors[op['inputs'][0]]
            size_detail = interpreter._get_tensor_details(op['inputs'][1])
            size = interpreter.get_tensor(size_detail['index'])
            size_height = size[0]
            size_width  = size[1]

            def upsampling2d_bilinear(x, size_height, size_width):
                if optimizing_for_edgetpu_flg:
                    return tf.image.resize_bilinear(x, (size_height, size_width))
                else:
                    return tfv2.image.resize(x, [size_height, size_width], method='bilinear')

            output_tensor = tf.keras.layers.Lambda(upsampling2d_bilinear, arguments={'size_height': size_height, 'size_width': size_width})(input_tensor)
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'DEQUANTIZE':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            weights_detail = interpreter._get_tensor_details(op['inputs'][0])
            weights = interpreter.get_tensor(weights_detail['index'])
            output_tensor = weights.astype(np.float32)
            tensors[output_detail['index']] = output_tensor
        else:
            raise ValueError(op_type)

    # Convolution2DTransposeBias
    input_tensor = tensors[241]
    weights = np.load('segment_Kernel').transpose(1,2,0,3).astype(np.float32)
    bias = np.load('segment_Bias').astype(np.float32)
    custom_trans = tf.nn.conv2d_transpose(input=input_tensor,
                                          filters=weights,
                                          output_shape=[1, height, width, 2],
                                          strides=[2, 2],
                                          padding='SAME',
                                          dilations=[1, 1])
    output_tensor = tf.math.add(custom_trans, bias, name='segment')
    tensors[999] = output_tensor


def main():

    tf.disable_eager_execution()

    gen_model_json()
    ops, op_types = parse_json()

    interpreter = tf.lite.Interpreter(model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print(input_details)
    print(output_details)

    make_graph(ops, op_types, interpreter)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    graph = tf.get_default_graph()

    with tf.Session(config=config, graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        graph_def = tf.graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=graph.as_graph_def(),
            output_node_names=output_node_names)

        with tf.io.gfile.GFile(output_pb_path, 'wb') as f:
            f.write(graph_def.SerializeToString())

        shutil.rmtree(output_savedmodel_path, ignore_errors=True)
        tf.saved_model.simple_save(
            sess,
            output_savedmodel_path,
            inputs={'input_1': graph.get_tensor_by_name('input_1:0')},
            outputs={'segment': graph.get_tensor_by_name('segment:0')}
        )

    converter = tfv2.lite.TFLiteConverter.from_saved_model(output_savedmodel_path)
    converter.target_spec.supported_ops = [tfv2.lite.OpsSet.TFLITE_BUILTINS, tfv2.lite.OpsSet.SELECT_TF_OPS]
    tflite_model = converter.convert()
    with open(f'{output_savedmodel_path}/model_float32.tflite', 'wb') as w:
        w.write(tflite_model)

if __name__ == '__main__':
    main()

"""
$ saved_model_cli show --dir saved_model --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input_1'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 144, 256, 3)
        name: input_1:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['segment'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 144, 256, 2)
        name: segment:0
  Method name is: tensorflow/serving/predict
"""

ご参考までに、最適化後のtfliteモデルの全体像を下図に示します。 どうでしょうか? 限界まで最適化されて美しくなっていますね。 標準オペレーションに置き換えたことにより、ノンカスタムの TensorFlow Lite で普通に実行できるようになっているはずです。 MediaPipe を利用すると爆速パフォーマンスを体験できますのでとてもメリットがあるのですが、実装がややこしいため通常のPythonロジックのみで開発できると選択肢の幅が広がりますね。
finalmodel

4. おわりに

この手順で変換した MediaPipe の各種モデルやスクリプトは全て下記のリポジトリにコミットされています。 tflite だけではなく、 EdgeTPU、TFJS、TF-TRT、CoreML、ONNX、OpenVINO IR の形式へ変換済みです。 ご参考までにどうぞ。

https://github.com/PINTO0309/PINTO_model_zoo
zoo
zoo

Discussion

ログインするとコメントできます