Open4
Trial and error to replace tf.gather with primitive OPs (tf.gather をプリミティブな OP へ置き換える試行錯誤)
- TensorFlow's tf.gather implementation appears to flatten both the tensor and the indices before processing.
Successfully implemented.
# Environment setup: pin RNG seeds for reproducibility and silence TensorFlow.
import os
import numpy as np
np.random.seed(0)
# Hide all GPUs and suppress TF C++ logging; these env vars must be set
# BEFORE tensorflow is imported to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
tf.random.set_seed(0)
######################################################################
### params_rank == 3 test data
# Table to gather from: 30522 rows of 128-dim float32 vectors
# (dimensions match a BERT word-embedding table).
input_shape = [30522, 128]
input_data = np.random.random(input_shape).astype(np.float32)
input_data = tf.convert_to_tensor(input_data)
# Symbolic index tensor; with batch_size=1 its full shape is [1, 384, 1].
params = tf.keras.Input(
    shape=[384, 1],
    batch_size=1,
    dtype=tf.int32,
)
### params_rank == 2 test data
# input_shape = [2, 512]
# input_data = np.random.random(input_shape).astype(np.float32)
# input_data = tf.convert_to_tensor(input_data)
# params = tf.keras.Input(
#     shape=[384],
#     batch_size=1,
#     dtype=tf.int32,
# )
### params_rank == 1 test data
# input_shape = [2, 512]
# input_data = np.random.random(input_shape).astype(np.float32)
# input_data = tf.convert_to_tensor(input_data)
# params = tf.keras.Input(
#     shape=[384],
#     batch_size=1,
#     dtype=tf.int32,
# )
# params = tf.squeeze(params, axis=0)
######################################################################
# --- Reference model: plain tf.gather, exported to TFLite -----------------
gather_axis = 0
reference_output = tf.gather(
    params=input_data,
    indices=params,
    axis=gather_axis,
)
model1 = tf.keras.Model(inputs=[params], outputs=[reference_output])
# Wrap the Keras model in a tf.function and trace a concrete function so
# the TFLite converter can consume it.
run_model = tf.function(lambda *args: model1(args))
input_specs = [tf.TensorSpec(t.shape, t.dtype) for t in model1.inputs]
concrete_func = run_model.get_concrete_function(*input_specs)
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_func]
)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()
with open(f'model_float32_gathered.tflite', 'wb') as w:
    w.write(tflite_model)
#################################################################
# --- Re-implementation of tf.gather using primitive ops -------------------
gather_axis = 0
input_shape_rank = len(input_shape)
params_shape = list(params.shape)  # includes the batch dim, e.g. [1, 384, 1]
params_dtype = params.dtype
params_rank = len(params_shape)
# Shape of the gathered result: input_shape with the gathered axis
# replaced by the full params shape (mirrors tf.gather's output shape).
target_gathered_shape = []
for idx, s in enumerate(input_shape):
    if idx != gather_axis:
        target_gathered_shape = target_gathered_shape + [s]
    else:
        target_gathered_shape = target_gathered_shape + params_shape
# Flatten the index tensor to 1-D so each scalar index can be sliced out
# individually later.
flattened_params = tf.reshape(
    tensor=params,
    shape=[np.prod(params_shape)],
)
# Workaround for GPU Delegate
if 3 <= input_shape_rank <= 4:
    # tf.strided_slice supports 3-D and 4-D tensor slices
    pass
elif input_shape_rank < 3:
    # tf.strided_slice does not support slices of tensors less than 3-D,
    # so it is forced to extend to 3-D
    expand_shape = [1 for _ in range(3 - input_shape_rank)] + input_shape
    input_data = tf.reshape(
        tensor=input_data,
        shape=expand_shape,
    )
    # Shift the gather axis past the newly prepended singleton dims.
    gather_axis = gather_axis + (3 - input_shape_rank)
    input_shape = input_data.shape
    input_shape_rank = len(input_shape)
else:
    # GPU Delegate unsupported error
    # NOTE(review): ranks > 4 fall through silently here — confirm this is
    # intentional best-effort behavior.
    pass
# --- Build an [N, 1] tensor of scalar indices -----------------------------
# tf.gather flattens the index tensor internally; reproduce that here by
# slicing out every element of ``flattened_params`` in row-major order and
# re-concatenating.  Each 1-element slice becomes its own primitive slice
# op in the converted graph.
#
# NOTE: the previous hand-unrolled implementation duplicated this logic for
# ranks 1-5 and contained a bug at rank 5 (the innermost loop iterated
# ``range(params_shape[3])`` instead of ``range(params_shape[4])``, giving
# wrong indices whenever those two dims differed).  Row-major enumeration
# of the flat element count is identical for every rank, so a single loop
# covers all ranks >= 1 and also generalizes beyond rank 5.
if params_rank < 1:
    # Previously this fell through with indices = None and crashed later
    # with an unrelated error; fail fast with a clear message instead.
    raise NotImplementedError(
        f'params_rank == {params_rank} is not supported'
    )
num_index_elements = int(np.prod(params_shape))
element_slices = [
    flattened_params[i:i + 1] for i in range(num_index_elements)
]
indices = tf.reshape(
    tensor=tf.concat(element_slices, axis=0),
    shape=[-1, 1],
)
# Slicing
# For every scalar index, cut a one-element slab out of ``input_data``
# along ``gather_axis`` with tf.strided_slice, then stitch the slabs back
# together and reshape to the final gathered shape.
#
# The begin/end vectors are built by zero-padding the scalar index so it
# lands at position ``gather_axis``:
#   [1,2,3],     axis:0 -> paddings [[0,2]]
#   [1,2,3],     axis:1 -> paddings [[1,1]]
#   [1,2,3],     axis:2 -> paddings [[2,0]]
#   [1,2,3,4,5], axis:0 -> paddings [[0,4]]
#   [1,2,3,4,5], axis:1 -> paddings [[1,3]]
#   [1,2,3,4,5], axis:2 -> paddings [[2,2]]
#   [1,2,3,4,5], axis:3 -> paddings [[3,1]]
#   [1,2,3,4,5], axis:4 -> paddings [[4,0]]
# Both values are loop-invariant, so compute them once up front.
slice_paddings = [[gather_axis, input_shape_rank - gather_axis - 1]]
# Bitmask with every bit set EXCEPT the gather axis: all other dimensions
# are taken in full, only the gather axis honours begin/end.
axis_mask = sum([2 ** dim for dim in range(input_shape_rank)]) - 2 ** gather_axis
slabs = []
for scalar_index in indices:
    begin_vec = tf.pad(
        tensor=scalar_index,
        paddings=slice_paddings,
    )
    end_vec = tf.pad(
        tensor=scalar_index + 1,
        paddings=slice_paddings,
    )
    slabs.append(
        tf.strided_slice(
            input_=input_data,
            begin=begin_vec,
            end=end_vec,
            begin_mask=axis_mask,
            end_mask=axis_mask,
        )
    )
sliced_concated_input_datas = tf.concat(
    values=slabs,
    axis=gather_axis,
)
gathered_input_datas = tf.reshape(
    tensor=sliced_concated_input_datas,
    shape=target_gathered_shape,
)
# --- Export the strided-slice model to TFLite -----------------------------
model2 = tf.keras.Model(inputs=[params], outputs=[gathered_input_datas])
run_model = tf.function(lambda *args: model2(args))
model2_specs = [tf.TensorSpec(t.shape, t.dtype) for t in model2.inputs]
concrete_func = run_model.get_concrete_function(*model2_specs)
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_func]
)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
# Float32 export.
tflite_model = converter.convert()
with open(f'model_float32_gather_stridedslice.tflite', 'wb') as w:
    w.write(tflite_model)
# Float16 export, then check the converted model for GPU-delegate
# compatibility with the TFLite analyzer.
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()
with open(f'model_float16_gather_stridedslice.tflite', 'wb') as w:
    w.write(tflite_model)
tf.lite.experimental.Analyzer.analyze(
    model_content=tflite_model,
    gpu_compatibility=True,
)
######################################################################
### params_rank == 3 test data
# Random indices in [0, 100); all valid rows since input_shape[0] == 30522.
val_data = np.random.randint(0, 100, params_shape, dtype=np.int32)
### params_rank == 2 test data
# val_data = np.random.randint(0, 1, params_shape, dtype=np.int32)
### params_rank == 1 test data
# val_data = np.random.randint(0, 1, params_shape, dtype=np.int32)
######################################################################
# Numerically compare the reference tf.gather model against the
# strided-slice re-implementation; prints True when they agree.
model1_outputs = model1(val_data)
model2_outputs = model2(val_data)
print(f'{np.allclose(model1_outputs, model2_outputs)}')
True  (output of the final np.allclose check above)