Open4

Trial and error to replace tf.gather with primitive OP (tf.gatherをプリミティブなOPへ置き換える試行錯誤)

PINTOPINTO

実装できた。 (Implementation completed successfully.)

import os
import numpy as np
np.random.seed(0)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
tf.random.set_seed(0)


######################################################################
### Test data: enable exactly one of the configurations below.
### `input_data` is the constant float32 table that is gathered from;
### `params` is the int32 index tensor fed to the model at run time.

### params_rank == 3 test data (active)
# batch_size=1 is fixed, so the full static shape of `params` is [1, 384, 1].
input_shape = [30522, 128]
input_data = tf.convert_to_tensor(
    np.random.random(input_shape).astype(np.float32)
)
params = tf.keras.Input(shape=[384, 1], batch_size=1, dtype=tf.int32)

### params_rank == 2 test data
# input_shape = [2, 512]
# input_data = tf.convert_to_tensor(
#     np.random.random(input_shape).astype(np.float32)
# )
# params = tf.keras.Input(shape=[384], batch_size=1, dtype=tf.int32)

### params_rank == 1 test data
# input_shape = [2, 512]
# input_data = tf.convert_to_tensor(
#     np.random.random(input_shape).astype(np.float32)
# )
# params = tf.keras.Input(shape=[384], batch_size=1, dtype=tf.int32)
# # Drop the batch dimension so the index tensor is rank 1.
# # NOTE(review): the squeezed tensor is no longer the tf.keras.Input
# # itself; confirm tf.keras.Model accepts it in `inputs=[params]`.
# params = tf.squeeze(params, axis=0)
######################################################################

# Reference model: the stock tf.gather op that the rest of this script
# re-implements with tf.strided_slice-based primitives.
gather_axis = 0

gathered_tensor = tf.gather(params=input_data, indices=params, axis=gather_axis)

inputs1, outputs1 = [params], [gathered_tensor]
model1 = tf.keras.Model(inputs=inputs1, outputs=outputs1)

# Trace the Keras model into a concrete function using the model's own
# input signature, then convert it to TFLite (builtins + Select TF ops).
run_model = tf.function(lambda *model_args : model1(model_args))
input_specs = [tf.TensorSpec(t.shape, t.dtype) for t in model1.inputs]
concrete_func = run_model.get_concrete_function(*input_specs)

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()
with open('model_float32_gathered.tflite', 'wb') as w:
    w.write(tflite_model)

#################################################################
### Shapes needed by the primitive-op re-implementation of tf.gather.

gather_axis = 0
input_shape_rank = len(input_shape)
params_shape = list(params.shape)
params_dtype = params.dtype
params_rank = len(params_shape)

# Expected gather output shape: the entry of `input_shape` at `gather_axis`
# is replaced, in place, by the full shape of the index tensor.
target_gathered_shape = [
    dim
    for axis, size in enumerate(input_shape)
    for dim in (params_shape if axis == gather_axis else [size])
]

# The index tensor flattened to 1-D (row-major order).
flattened_params = tf.reshape(
    tensor=params,
    shape=[np.prod(params_shape)],
)

# Workaround for GPU Delegate.
# Per the original notes, tf.strided_slice there handles 3-D and 4-D tensors
# but not lower ranks, so lower-rank inputs are lifted to 3-D by prepending
# size-1 axes. Ranks 3-4 are used as-is; ranks above 4 are also left as-is
# (GPU Delegate unsupported).
if input_shape_rank < 3:
    n_new_axes = 3 - input_shape_rank
    expand_shape = [1] * n_new_axes + input_shape
    input_data = tf.reshape(tensor=input_data, shape=expand_shape)
    # The prepended axes shift the gather axis to the right.
    gather_axis += n_new_axes
    input_shape = input_data.shape
    input_shape_rank = len(input_shape)

# Gather indices as a column vector of shape [N, 1], N = prod(params_shape),
# in row-major order; the slicing loop iterates over its rows.
#
# `flattened_params` is already the row-major flattening of `params`, so a
# single reshape gives the index list for any params_rank. The former
# per-rank nested loops sliced every element out of `flattened_params` and
# concatenated the slices back in the same order (an identity that emitted
# one StridedSlice per element plus Concatenation ops). In addition, their
# rank-5 branch bounded its innermost loop by params_shape[3] instead of
# params_shape[4], and ranks above 5 left `indices` as None.
indices = tf.reshape(tensor=flattened_params, shape=[-1, 1])

# Slicing: emulate tf.gather with one tf.strided_slice per gathered index.
#
# Each `idx` is expected to be a length-1 index vector (a row of `indices`).
# Padding it yields a begin/end vector of length `input_shape_rank` that is
# zero everywhere except at `gather_axis`, where it holds idx (begin) or
# idx + 1 (end). Pad widths per (shape, axis), for example:
#     [1,2,3], axis:0 -> [[0,2]]        [1,2,3,4,5], axis:0 -> [[0,4]]
#     [1,2,3], axis:1 -> [[1,1]]        [1,2,3,4,5], axis:2 -> [[2,2]]
#     [1,2,3], axis:2 -> [[2,0]]        [1,2,3,4,5], axis:4 -> [[4,0]]
# begin_mask/end_mask have every bit set except `gather_axis`, so all other
# axes are taken in full and only a length-1 window along `gather_axis`
# remains. Both depend only on rank and axis, so they are computed once.
slice_paddings = [[gather_axis, input_shape_rank - gather_axis - 1]]
mask = (2 ** input_shape_rank - 1) - 2 ** gather_axis

sliced_input_datas = []
for idx in indices:
    padded_begin = tf.pad(tensor=idx, paddings=slice_paddings)
    padded_end = tf.pad(tensor=idx + 1, paddings=slice_paddings)
    sliced_input_data = tf.strided_slice(
        input_=input_data,
        begin=padded_begin,
        end=padded_end,
        begin_mask=mask,
        end_mask=mask,
    )
    sliced_input_datas.append(sliced_input_data)

# Stitch the length-1 slices back together along the gather axis, then
# reshape to the output shape tf.gather would produce.
sliced_concated_input_datas = tf.concat(
    values=sliced_input_datas,
    axis=gather_axis,
)
gathered_input_datas = tf.reshape(
    tensor=sliced_concated_input_datas,
    shape=target_gathered_shape,
)


# Model built from the primitive-op gather; export it as float32 and float16
# TFLite models and report GPU delegate compatibility.
inputs2 = [params]
outputs2 = [gathered_input_datas]
model2 = tf.keras.Model(inputs=inputs2, outputs=outputs2)
run_model = tf.function(lambda *inputs2 : model2(inputs2))
concrete_func = run_model.get_concrete_function(
    *[tf.TensorSpec(tensor.shape, tensor.dtype) for tensor in model2.inputs]
)
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_func]
)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()
with open('model_float32_gather_stridedslice.tflite', 'wb') as w:
    w.write(tflite_model)

# Float16 post-training quantization: target_spec.supported_types is only
# honoured when optimizations are enabled. Without Optimize.DEFAULT the
# converter emits the same float32 model again under the float16 filename.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()
with open('model_float16_gather_stridedslice.tflite', 'wb') as w:
    w.write(tflite_model)

# Op-by-op report for the float16 model, including GPU delegate compatibility.
tf.lite.experimental.Analyzer.analyze(
    model_content=tflite_model,
    gpu_compatibility=True,
)


######################################################################
### Validation: run the reference tf.gather model and the StridedSlice
### re-implementation on the same random indices and compare outputs.
# Indices are drawn from the whole valid range of the gathered axis rather
# than a hard-coded bound, so every test-data configuration is exercised
# meaningfully. `input_data` and `gather_axis` may have been expanded or
# shifted by the GPU delegate workaround, but they are adjusted together and
# still identify the gathered axis. (The earlier rank-1/rank-2 variants used
# randint(0, 1, ...), which only ever yields index 0 and cannot detect
# index-mapping errors.)
gather_axis_size = int(input_data.shape[gather_axis])
val_data = np.random.randint(0, gather_axis_size, params_shape, dtype=np.int32)
######################################################################

model1_outputs = model1(val_data)
model2_outputs = model2(val_data)

print(f'{np.allclose(model1_outputs, model2_outputs)}')