Open1

EdgeTPU最適化のためのArgMax置き換え

PINTO
import tensorflow as tf
import os
from pprint import pprint
import numpy as np
# Intended to force CPU-only execution by hiding every CUDA device from TensorFlow.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
np.random.seed(0)

# Deterministic probe input of shape (1, 1, 2, 3) holding the values 0..5 as float32.
dummy_input = np.arange(6, dtype=np.float32).reshape(1, 1, 2, 3)
print(dummy_input)

# When True, skip building/converting the models and only run the existing .tflite files.
testonly = False

if not testonly:

    # Symbolic Keras input mirroring dummy_input: the per-sample shape comes from
    # dims 1..3 and the fixed batch size from dim 0.
    i = tf.keras.layers.Input(
        shape=list(dummy_input.shape[1:]),
        batch_size=dummy_input.shape[0],
    )

    def _nnapi_scalar(value, dtype):
        """Wrap a Python scalar as a constant one-element tensor of `dtype`.

        Works around the NNAPI error "Scalar operand should be constant" by
        materializing the scalar as a shape-(1,) constant; the price is that the
        other operand is combined with it via broadcasting.
        """
        one_element = (1,)
        return tf.constant(value, shape=one_element, dtype=dtype)

    def argmax(
        input_tensor,
        axis=-1,
        output_type=tf.dtypes.float32,
        name=None,
        keepdims=False,
        epsilon=None,
    ):
        """Return the index of the largest value along `axis` without an ArgMax op.

        Approximates tf.compat.v1.argmax using only reduce_max, subtract,
        minimum, multiply, range, expand_dims and cast. The goal is to avoid the
        ArgMax op when targeting EdgeTPU. It is not exactly equivalent: for
        floating inputs, a value closer than `epsilon` to the maximum (but not
        equal to it) produces an ill-defined result. When several positions
        hold the maximum, the smallest index is returned.

        Scalar constants are created with shape (1,) (see _nnapi_scalar). For a
        rank-1 input with keepdims=False, the result therefore has shape (1,)
        rather than being a scalar.

        Args:
            input_tensor: A floating or integer Tensor. Its size along `axis` is
                used as a Python value, so it presumably must be statically
                known.
            axis: Python int in [-rank(input), rank(input)); the axis to reduce
                across. For vectors, use axis=0.
            output_type: tf.DType of the result. The default (float32) differs
                from tflite's int64 so that the default is compatible with
                darwinn.
            name: Optional name, applied to the final operation only.
            keepdims: If True, retain the reduced axis with length 1.
            epsilon: Optional small number, intended to stay below the
                quantization threshold, used to tell values equal to the max
                from values below it for floating inputs. Falsy values fall back
                to 1e-6. Unused for integer inputs.

        Returns:
            A Tensor of dtype `output_type`.

        Raises:
            ValueError: If `input_tensor` is neither floating nor integer.
        """
        rank = len(input_tensor.shape)
        # Non-negative copy of `axis`, needed to count the trailing dims below.
        safe_axis = axis + rank if axis < 0 else axis
        reduction_size = input_tensor.shape[axis]
        axis_max = tf.math.reduce_max(input_tensor, axis=axis, keepdims=True)
        # 0 where the element equals the max along `axis`, positive elsewhere.
        zero_if_max = tf.subtract(axis_max, input_tensor)
        eps = epsilon if epsilon else 1e-6
        if input_tensor.dtype.is_floating:
            # Clamp positive gaps to eps, then rescale them so they become 1.
            zero_if_max_else_eps = tf.math.minimum(
                _nnapi_scalar(eps, input_tensor.dtype), zero_if_max)
            zero_if_max_else_one = zero_if_max_else_eps * _nnapi_scalar(
                1 / eps, input_tensor.dtype)
        elif input_tensor.dtype.is_integer:
            # Non-max integer gaps are >= 1, so clamping to 1 is enough.
            # NOTE(review): for signed dtypes, max - x can wrap around when the
            # values span more than the dtype's positive range (e.g. int8 holding
            # both 127 and -128) -- confirm such inputs cannot occur.
            zero_if_max_else_one = tf.math.minimum(
                _nnapi_scalar(1, input_tensor.dtype), zero_if_max)
        else:
            # epsilon cannot help here: only floating and integer dtypes are handled.
            raise ValueError(
                f'argmax supports only floating or integer inputs, '
                f'got {input_tensor.dtype}')

        # Input dtype ends here; everything below is computed in `output_type`.
        zero_if_max_else_one = tf.cast(zero_if_max_else_one, dtype=output_type)
        one_if_max_else_zero = tf.math.subtract(
            _nnapi_scalar(1, output_type), zero_if_max_else_one)
        # Reversed positions [size, ..., 1]: reduce_max over the masked values
        # then selects the smallest index that holds the maximum.
        rev_index = tf.range(reduction_size, 0, -1, dtype=output_type)
        # Append trailing singleton dims so rev_index broadcasts along `axis` only.
        for index in range(safe_axis + 1, rank):
            rev_index = tf.expand_dims(rev_index, axis=index - safe_axis)
        rev_index_if_max_else_zero = tf.math.multiply(one_if_max_else_zero, rev_index)
        reverse_argmax = tf.math.reduce_max(
            rev_index_if_max_else_zero, axis=axis, keepdims=keepdims)
        # Only the final operation gets `name`. Also passing it to reduce_max
        # could leave this op with a uniquified name (e.g. "name_1") in graph mode.
        return tf.math.subtract(
            _nnapi_scalar(reduction_size, output_type), reverse_argmax, name=name)

    # Reference output (stock tf.argmax) and the EdgeTPU-oriented replacement,
    # both computed from the same symbolic input so their results can be compared.
    a = tf.argmax(i, axis=3)
    b = argmax(input_tensor=i, axis=3, keepdims=False)

    output_path = 'saved_model'
    for output, tflite_name in (
        (a, 'test_tf_argmax.tflite'),
        (b, 'test_custom_argmax.tflite'),
    ):
        model = tf.keras.models.Model(inputs=i, outputs=output)
        model.summary()
        # NOTE: both iterations save into the same directory, so the second
        # SavedModel (custom argmax) overwrites the first; the .tflite files
        # written below use distinct names and both remain.
        tf.saved_model.save(model, output_path)
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        # Builtin TFLite kernels, with fallback to TensorFlow (Select/Flex) ops.
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS,
        ]
        tflite_model = converter.convert()
        # Context manager so the file is flushed and closed deterministically.
        with open(f"{output_path}/{tflite_name}", "wb") as tflite_file:
            tflite_file.write(tflite_model)

# Run each converted model (stock tf.argmax first, then the custom argmax) on
# dummy_input with the TFLite interpreter and pretty-print its raw output.
for tflite_path in (
    'saved_model/test_tf_argmax.tflite',
    'saved_model/test_custom_argmax.tflite',
):
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()
    input_details, output_details = (
        interpreter.get_input_details(),
        interpreter.get_output_details(),
    )
    interpreter.set_tensor(input_details[0]['index'], dummy_input)
    interpreter.invoke()
    ret = interpreter.get_tensor(output_details[0]['index'])
    print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ Float32')
    pprint(ret)