🤪

TensorFlow 2 Detection Model Zoo モデルのtflite変換

2021/05/01に公開

1. はじめに

TensorFlow 2 Detection Model Zoo というリポジトリがあります。 こちらには TensorFlow v2 ベースのトレーニング済みの様々な物体検出モデルがコミットされているのですが、正規の手順、というか、確立された手順がどこにも言及されていないため、自力で適当に変換しました。手順をシェアします。秒殺で終わるので完全に書き捨ての記事です。
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md

2. 手順

トレーニング済みのモデルをダウンロードして解凍します。

$ wget http://download.tensorflow.org/models/object_detection/tf2/20200711/centernet_resnet50_v2_512x512_kpts_coco17_tpu-8.tar.gz
$ tar -zxvf centernet_resnet50_v2_512x512_kpts_coco17_tpu-8.tar.gz
$ rm centernet_resnet50_v2_512x512_kpts_coco17_tpu-8.tar.gz
$ cd centernet_resnet50_v2_512x512_kpts_coco17_tpu-8

ダウンロードして解凍したときに同梱されている saved_model の入力と出力の名前や形状を確認します。下記のコマンドを実行します。

$ saved_model_cli show \
  --dir saved_model \
  --tag_set serve \
  --signature_def serving_default

The given SavedModel SignatureDef contains the following input(s):
  inputs['input_tensor'] tensor_info:
      dtype: DT_UINT8
      shape: (1, -1, -1, 3)
      name: serving_default_input_tensor:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['detection_boxes'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100, 4)
      name: StatefulPartitionedCall:0
  outputs['detection_classes'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100)
      name: StatefulPartitionedCall:1
  outputs['detection_keypoint_scores'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100, 17)
      name: StatefulPartitionedCall:2
  outputs['detection_keypoints'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100, 17, 2)
      name: StatefulPartitionedCall:3
  outputs['detection_scores'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100)
      name: StatefulPartitionedCall:4
  outputs['num_detections'] tensor_info:
      dtype: DT_FLOAT
      shape: (1)
      name: StatefulPartitionedCall:5
Method name is: tensorflow/serving/predict

下記のロジックを作成して test.py として保存します。 concrete function というものを利用して saved_model の入力解像度 [1, -1, -1, 3][1, 320, 320, 3] に固定します。変更先の解像度は自由に変更してください。tfliteはバッチサイズ以外の次元が不定の状態 (-1) のままだと変換に失敗します。 なお、入力のTensorが複数ある場合は、 input_shapes = [[1,320,320,3]] の部分を input_shapes = [[1,320,320,3],[1,10,10,1]] のようにリストでつなげて記載するだけです。

### tf-nightly==2.6.0-dev20210430

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph

# Load saved_model and change input shape
# https://github.com/tensorflow/tensorflow/issues/30180#issuecomment-505959220
model = tf.saved_model.load('saved_model')
input_shapes = [[1,320,320,3]]

concrete_func = \
    model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
concrete_func_input_tensors = \
    [tensor for tensor in concrete_func.inputs if tensor.dtype != tf.resource and not 'unknown' in tensor.name]

for conc_input, def_input in zip(concrete_func_input_tensors, input_shapes):
    print('Before changing the input shape', conc_input)
    conc_input.set_shape(def_input)
    print('After  changing the input shape', conc_input)

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.target_ops = \
    [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open("test.tflite", "wb").write(tflite_model)

TensorFlowの環境をクリーニングします。ご自身の手元の環境を壊したくない場合は、こちらのDockerコンテナをご利用ください。 Docker HubからPullして何度でも破壊できる環境がすぐに利用できます。
https://github.com/PINTO0309/tflite2tensorflow

$ sudo pip3 uninstall -y tensorboard-plugin-wit tb-nightly tensorboard \
                      tf-estimator-nightly tensorflow-gpu \
                      tensorflow tf-nightly tensorflow_estimator

tf-nightly を導入します。

$ sudo pip3 install tf-nightly==2.6.0-dev20210430

変換スクリプトを実行します。

$ python3 test.py

終わり。簡単ですね。 test.tflite というファイルが生成されているのが確認できます。
result

Discussion