Open3

TensorFlowの6D以上のTransposeがTensorFlow LiteでFlexTransposeになってしまう問題のワークアラウンド(5D以下の転置処理への強制置き換え)

PINTOPINTO
  • Numpy - 6次元のTransposeを4次元のTransposeに置き換えて処理する
import os
import math
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)
warnings.simplefilter(action='ignore', category=DeprecationWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
import logging
import random
random.seed(0)
import numpy as np
np.random.seed(0)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
tf.random.set_seed(0)
tf.keras.utils.set_random_seed(0)
tf.config.experimental.enable_op_determinism()
tf.get_logger().setLevel('INFO')
tf.autograph.set_verbosity(0)
tf.get_logger().setLevel(logging.FATAL)

###################################################################### パターン1
"""
original_shapes = [1,1,3,48,3,80]
target_perm = [0,1,2,4,3,5]
target_shapes = [1,1,3,3,48,80]
x_shape_one_dims = [0,1]

# 1. target_permから0と1を消す
remove_one_target_perm = [2,4,3,5]

# 2. 小さい数値から順番にゼロからの連番を付与する
replaced_remove_one_target_perm = [0,2,1,3]

# 3. original_shapes を squeeze する
squeezed_original_shapes = [3,48,3,80]

# 4. 3を2のpermで転置する
transposed_4d = [3,3,48,80]

# 5. x_shape_one_dimsの数値がtarget_permのどの位置にあったかを意識してReshapeする
shape_6d = transposed_4d.reshape(target_shapes)

# 6. Fin
shape_6d = [1,1,3,3,48,80]
"""
original_shapes = [1,1,3,48,3,80]
target_perm = [0,1,2,4,3,5]
target_shapes = [1,1,3,3,48,80]

# オリジナルデータ生成
x = np.arange(1, math.prod(original_shapes) + 1).reshape(original_shapes)

# オリジナルデータの形状把握
x_shape = x.shape
# 要素数が1の次元を取得
x_shape_one_dims = [
    idx for idx in range(len(x_shape)) if x_shape[idx]==1
]
# 要素数が1の次元を削除
squeezed_original_x = np.squeeze(x, tuple(x_shape_one_dims))
# 要素数が1の次元を削除した形状を取得
squeezed_original_shapes = squeezed_original_x.shape

# オリジナルデータが5次元以上 なおかつ
# 要素数1の次元を削除したデータが4次元以下のときのみ特殊Transposeを実施
if len(x_shape) >= 5 and len(squeezed_original_shapes) <= 4:
    remove_one_target_perm = [
        idx for idx in target_perm if idx not in x_shape_one_dims
    ]
    sorted_remove_one_target_perm = sorted(remove_one_target_perm)
    replaced_remove_one_target_perm = [
        sorted_remove_one_target_perm.index(idx) \
            for idx in remove_one_target_perm
    ]
    transposed_no_one_data = \
        squeezed_original_x.transpose(replaced_remove_one_target_perm)
    transposed_data = \
        transposed_no_one_data.reshape(target_shapes)

print(f'**** Numpy pattern.1 check: {np.array_equal(transposed_data, x.transpose(target_perm))}')


###################################################################### パターン2
"""
original_shapes = [3,48,1,3,80,1]
target_perm = [0,2,1,3,5,4]
target_shapes = [3,1,48,3,1,80]
x_shape_one_dims = [2,5]

# 1. 2と5を消す
remove_one_target_perm = [0,1,3,4]

# 2. 小さい数値から順番にゼロからの連番を付与する
replaced_remove_one_target_perm = [0,1,2,3]

# 3. original_shapes を squeeze する
squeezed_original_shapes = [3,48,3,80]

# 4. 3を2のpermで転置する
transposed_4d = [3,48,3,80]

# 5. x_shape_one_dimsの数値がtarget_permのどの位置にあったかを意識してReshapeする
shape_6d = transposed_4d.reshape(target_shapes)

# 6. Fin
shape_6d = [3,1,48,3,1,80]
"""
original_shapes = [3,48,1,3,80,1]
target_perm = [0,2,1,3,5,4]
target_shapes = [3,1,48,3,1,80]

# オリジナルデータ生成
x = np.arange(1, math.prod(original_shapes) + 1).reshape(original_shapes)

# オリジナルデータの形状把握
x_shape = x.shape
# 要素数が1の次元を取得
x_shape_one_dims = [
    idx for idx in range(len(x_shape)) if x_shape[idx]==1
]
# 要素数が1の次元を削除
squeezed_original_x = np.squeeze(x, tuple(x_shape_one_dims))
# 要素数が1の次元を削除した形状を取得
squeezed_original_shapes = squeezed_original_x.shape

# オリジナルデータが5次元以上 なおかつ
# 要素数1の次元を削除したデータが4次元以下のときのみ特殊Transposeを実施
if len(x_shape) >= 5 and len(squeezed_original_shapes) <= 4:
    remove_one_target_perm = [
        idx for idx in target_perm if idx not in x_shape_one_dims
    ]
    sorted_remove_one_target_perm = sorted(remove_one_target_perm)
    replaced_remove_one_target_perm = [
        sorted_remove_one_target_perm.index(idx) \
            for idx in remove_one_target_perm
    ]
    transposed_no_one_data = \
        squeezed_original_x.transpose(replaced_remove_one_target_perm)
    transposed_data = \
        transposed_no_one_data.reshape(target_shapes)

print(f'**** Numpy pattern.2 check: {np.array_equal(transposed_data, x.transpose(target_perm))}')
**** Numpy pattern.1 check: True
**** Numpy pattern.2 check: True
PINTOPINTO
  • TensorFlow - 6次元のTransposeを4次元のTransposeに置き換えて処理する
import os
import math
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)
warnings.simplefilter(action='ignore', category=DeprecationWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
import logging
import random
random.seed(0)
import numpy as np
np.random.seed(0)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
tf.random.set_seed(0)
tf.keras.utils.set_random_seed(0)
tf.config.experimental.enable_op_determinism()
tf.get_logger().setLevel('INFO')
tf.autograph.set_verbosity(0)
tf.get_logger().setLevel(logging.FATAL)

original_shapes = [3,48,1,3,80,1]
target_perm = [0,2,1,3,5,4]
target_shapes = [3,1,48,3,1,80]

# オリジナルデータ生成
npx = np.arange(1, math.prod(original_shapes) + 1).reshape(original_shapes)
x = tf.constant(npx)

# オリジナルデータの形状把握
x_shape = x.shape
# 要素数が1の次元を取得
x_shape_one_dims = [
    idx for idx in range(len(x_shape)) \
        if isinstance(x_shape[idx], int) and x_shape[idx]==1
]
# 要素数が1の次元を削除
squeezed_original_x = tf.squeeze(x, x_shape_one_dims)
# 要素数が1の次元を削除した形状を取得
squeezed_original_shapes = squeezed_original_x.shape

# オリジナルデータが5次元以上 なおかつ
# 要素数1の次元を削除したデータが4次元以下のときのみ特殊Transposeを実施
if len(x_shape) >= 5 and len(squeezed_original_shapes) <= 4:
    remove_one_target_perm = [
        idx for idx in target_perm if idx not in x_shape_one_dims
    ]
    sorted_remove_one_target_perm = sorted(remove_one_target_perm)
    replaced_remove_one_target_perm = [
        sorted_remove_one_target_perm.index(idx) \
            for idx in remove_one_target_perm
    ]
    transposed_no_one_data = \
        tf.transpose(
            a=squeezed_original_x,
            perm=replaced_remove_one_target_perm,
        )
    transposed_data = \
        tf.reshape(
            tensor=transposed_no_one_data,
            shape=target_shapes,
        )

print(f'**** TF check: {np.array_equal(transposed_data, tf.transpose(x, target_perm))}')
**** TF check: True
PINTOPINTO
  • TensorFlow Lite - 6次元のTransposeを4次元のTransposeに置き換えて処理し、FlexTranspose の生成を抑止する
import os
import math
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)
warnings.simplefilter(action='ignore', category=DeprecationWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
import logging
import random
random.seed(0)
import numpy as np
np.random.seed(0)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
tf.random.set_seed(0)
tf.keras.utils.set_random_seed(0)
tf.config.experimental.enable_op_determinism()
tf.get_logger().setLevel('INFO')
tf.autograph.set_verbosity(0)
tf.get_logger().setLevel(logging.FATAL)

original_shapes = [3,48,1,3,80,1]
target_perm = [0,2,4,3,5,1]
target_shapes = [3,1,80,3,1,48]

inputs = tf.keras.Input(
    shape=original_shapes[1:],
    batch_size=original_shapes[0] \
        if isinstance(original_shapes[0], int) else None,
    dtype=tf.float32,
)
# オリジナルデータの形状把握
x_shape = inputs.shape
# 要素数が1の次元を取得
x_shape_one_dims = [
    idx for idx in range(len(x_shape)) \
        if isinstance(x_shape[idx], int) and x_shape[idx]==1
]
# 要素数が1の次元を削除
squeezed_original_x = tf.squeeze(inputs, x_shape_one_dims)
# 要素数が1の次元を削除した形状を取得
squeezed_original_shapes = squeezed_original_x.shape

# オリジナルデータが6次元以上 なおかつ
# 要素数1の次元を削除したデータが5次元以下のときのみ特殊Transposeを実施
if len(x_shape) >= 6 and len(squeezed_original_shapes) <= 5:
    remove_one_target_perm = [
        idx for idx in target_perm if idx not in x_shape_one_dims
    ]
    sorted_remove_one_target_perm = sorted(remove_one_target_perm)
    replaced_remove_one_target_perm = [
        sorted_remove_one_target_perm.index(idx) \
            for idx in remove_one_target_perm
    ]
    transposed_no_one_data = \
        tf.transpose(
            a=squeezed_original_x,
            perm=replaced_remove_one_target_perm,
        )
    transposed_data = \
        tf.reshape(
            tensor=transposed_no_one_data,
            shape=target_shapes,
        )

model = tf.keras.Model(inputs=inputs, outputs=transposed_data)
run_model = tf.function(lambda *inputs : model(inputs))
concrete_func = run_model.get_concrete_function(
    *[tf.TensorSpec(tensor.shape, tensor.dtype) for tensor in model.inputs]
)
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_func]
)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()
with open(f'model_float32.tflite', 'wb') as w:
    w.write(tflite_model)