Replacing ArgMax for EdgeTPU optimization

The script below builds two tiny Keras models, one using the standard tf.argmax and one using an argmax decomposed into simpler reduction and arithmetic ops, converts both to TFLite, and runs them on the same dummy input so the outputs can be compared.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # hide GPUs before TensorFlow initializes them
from pprint import pprint
import numpy as np
import tensorflow as tf
np.random.seed(0)
dummy_input = np.arange(6).reshape(1,1,2,3).astype(np.float32)
print(dummy_input)
testonly = False
if not testonly:
    # Create a model
    i = tf.keras.layers.Input(
        shape=[
            dummy_input.shape[1],
            dummy_input.shape[2],
            dummy_input.shape[3],
        ],
        batch_size=dummy_input.shape[0],
    )

    def _nnapi_scalar(value, dtype):
        # Resolves "Scalar operand should be constant" at the cost of broadcasting
        return tf.constant(value, dtype=dtype, shape=(1,))

    def argmax(
        input_tensor,
        axis=-1,
        output_type=tf.dtypes.float32,
        name=None,
        keepdims=False,
        epsilon=None,
    ):
        """Returns the index with the largest value across an axis of a tensor.

        Approximately equivalent to tf.compat.v1.argmax, but not identical: if
        arithmetic allows a value to come anomalously close to the maximum
        without being equal to it, the behavior is undefined.

        Args:
            input_tensor: A Tensor.
            axis: An int. Must be in the range [-rank(input), rank(input)).
                Describes which axis of the input tensor to reduce across.
                For vectors, use axis = 0.
            output_type: An optional tf.DType. Note that the default differs from
                tflite's (int64) so that the default behavior is compatible with
                darwinn.
            name: Optional name for the operations.
            keepdims: If true, retains reduced dimensions with length 1.
            epsilon: Optional small number which is intended to always be below
                the quantization threshold, used to distinguish equal and
                not-equal numbers.

        Returns:
            A Tensor of type output_type.
        """
        safe_axis = axis
        if safe_axis < 0:
            safe_axis = len(input_tensor.shape) + safe_axis
        reduction_size = input_tensor.shape[axis]
        axis_max = tf.math.reduce_max(input_tensor, axis=axis, keepdims=True)
        zero_if_max = tf.subtract(axis_max, input_tensor)
        eps = epsilon if epsilon else 1e-6
        if input_tensor.dtype.is_floating:
            zero_if_max_else_eps = tf.math.minimum(
                _nnapi_scalar(eps, input_tensor.dtype), zero_if_max)
            zero_if_max_else_one = zero_if_max_else_eps * _nnapi_scalar(1 / eps, input_tensor.dtype)
        elif input_tensor.dtype.is_integer:
            zero_if_max_else_one = tf.math.minimum(
                _nnapi_scalar(1, input_tensor.dtype), zero_if_max)
        else:
            raise ValueError('Please specify epsilon for unknown input data type')
        # Input type ends here, output type starts here
        zero_if_max_else_one = tf.cast(zero_if_max_else_one, dtype=output_type)
        one_if_max_else_zero = tf.math.subtract(
            _nnapi_scalar(1, output_type), zero_if_max_else_one)
        rev_index = tf.range(reduction_size, 0, -1, dtype=output_type)
        for index in range(safe_axis + 1, len(input_tensor.shape)):
            rev_index = tf.expand_dims(rev_index, axis=index - safe_axis)
        rev_index_if_max_else_zero = tf.math.multiply(one_if_max_else_zero, rev_index)
        reverse_argmax = tf.math.reduce_max(
            rev_index_if_max_else_zero, axis=axis, keepdims=keepdims)
        # The final operation takes the provided name, so the output tensor is
        # named like an argmax layer if a name was given.
        return tf.math.subtract(
            _nnapi_scalar(reduction_size, output_type), reverse_argmax, name=name)
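
    # Added sketch (not part of the original script): a quick eager-mode sanity
    # check that the replacement matches tf.argmax on the dummy input. tf.argmax
    # returns int64 indices; the custom argmax returns the same indices as float32.
    print('eager tf.argmax     :', tf.argmax(dummy_input, axis=3).numpy())
    print('eager custom argmax :', argmax(tf.constant(dummy_input), axis=3).numpy())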
    a = tf.argmax(i, axis=3)
    b = argmax(input_tensor=i, axis=3, keepdims=False)

    # Convert the standard tf.argmax model (for comparison)
    model = tf.keras.models.Model(inputs=i, outputs=a)
    model.summary()
    output_path = 'saved_model'
    tf.saved_model.save(model, output_path)
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_model = converter.convert()
    with open(f"{output_path}/test_tf_argmax.tflite", "wb") as f:
        f.write(tflite_model)
    # Convert the custom argmax model (EdgeTPU-friendly replacement)
    model = tf.keras.models.Model(inputs=i, outputs=b)
    model.summary()
    output_path = 'saved_model'
    tf.saved_model.save(model, output_path)
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_model = converter.convert()
    with open(f"{output_path}/test_custom_argmax.tflite", "wb") as f:
        f.write(tflite_model)
# Float32 inference with the standard tf.argmax model
interpreter = tf.lite.Interpreter('saved_model/test_tf_argmax.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], dummy_input)
interpreter.invoke()
ret = interpreter.get_tensor(output_details[0]['index'])
print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ Float32 - tf.argmax')
pprint(ret)
# Float32 inference with the custom argmax model
interpreter = tf.lite.Interpreter('saved_model/test_custom_argmax.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], dummy_input)
interpreter.invoke()
ret = interpreter.get_tensor(output_details[0]['index'])
print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ Float32 - custom argmax')
pprint(ret)
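
# Added sketch (not in the original): re-run both .tflite files through a small
# helper and check that the results agree. 'run_tflite' is a name introduced here
# for illustration only; the custom model's float32 output is cast back to the
# integer index dtype before comparison.
def run_tflite(path, x):
    itp = tf.lite.Interpreter(model_path=path)
    itp.allocate_tensors()
    itp.set_tensor(itp.get_input_details()[0]['index'], x)
    itp.invoke()
    return itp.get_tensor(itp.get_output_details()[0]['index'])

ret_tf = run_tflite('saved_model/test_tf_argmax.tflite', dummy_input)
ret_custom = run_tflite('saved_model/test_custom_argmax.tflite', dummy_input)
np.testing.assert_array_equal(ret_tf, ret_custom.astype(ret_tf.dtype))
print('tf.argmax and the custom argmax agree on dummy_input')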