Open1
gather_nd (GatherND) の Unity Barracuda 対応のためのワークアラウンドメモ
def barracuda_gather_nd(params, indices):
    """Emulate ``tf.gather_nd`` using reshape, multiply, reduce_sum and gather.

    Used in place of a direct ``tf.gather_nd`` call when
    ``optimizing_barracuda`` is enabled. Each index tuple in ``indices`` is
    folded into a single offset into a flattened view of ``params``, and the
    lookup itself is then performed with ``tf.gather``.

    Args:
        params: Tensor to gather from. Must have 3 dimensions, or 4
            dimensions with a leading dimension of 1 (which is dropped).
        indices: Integer index tensor under the same rank restriction. Its
            last dimension is the length of each index tuple.

    Returns:
        The gathered values reshaped to ``indices.shape[:-1]`` followed by
        the un-indexed trailing dimensions of ``params`` (both taken after
        the leading dimension was dropped), with a new axis of size 1
        inserted at position 0.

    Prints an error and calls ``sys.exit(-1)`` when either argument has an
    unsupported rank or batch size.

    NOTE(review): the shape arithmetic converts static shapes to tensors, so
    this presumably requires fully defined shapes -- confirm against callers.
    """

    def _drop_unit_batch(tensor):
        # Rank 4 with a leading dimension of 1 is reduced to rank 3; rank 3
        # passes through unchanged; anything else aborts. The message reads
        # the enclosing function's current `params` / `indices`, so it
        # reports the same shapes the original inline checks did.
        rank = len(tensor.shape)
        if rank == 4 and tensor.shape[0] == 1:
            return tensor[0]
        if rank != 3:
            print(f'{Color.RED}ERROR:{Color.RESET} gather_nd when optimizing_barracuda is enabled must have 4 dimensions and batch size = 1 or 3 dimensions.')
            print(f'{Color.RED}ERROR:{Color.RESET} params.shape: {params.shape}, indices.shape: {indices.shape}')
            sys.exit(-1)
        return tensor

    # Validate indices first, then params (same order as before).
    indices = _drop_unit_batch(indices)
    params = _drop_unit_batch(params)

    index_shape = indices.shape
    source_shape = params.shape
    tuple_len = index_shape[-1]
    # Dimensions of `params` that are not addressed by an index tuple.
    slice_shape = source_shape[tuple_len:]

    # Collapse the indexed leading dimensions of `params` into one axis.
    flat_params = tf.reshape(params, tf.concat([[-1], slice_shape], axis=0))

    # Row-major strides of the indexed dimensions: the exclusive, reversed
    # cumulative product of their sizes.
    strides = tf.math.cumprod(source_shape[:tuple_len], exclusive=True, reverse=True)
    # tf.math.multiply requires both operands to share a dtype; the
    # shape-derived strides need not match `indices` (e.g. int64 indices).
    strides = tf.cast(strides, indices.dtype)

    # Dot each index tuple with the strides to get its flat offset.
    flat_indices = tf.reduce_sum(tf.math.multiply(indices, strides), axis=-1)
    gathered = tf.gather(flat_params, flat_indices)

    out_shape = tf.concat([index_shape[:-1], slice_shape], axis=0)
    # Re-add a leading axis of size 1 on the way out.
    return tf.expand_dims(tf.reshape(gathered, out_shape), axis=0)
# With Barracuda optimization enabled, route through the barracuda_gather_nd
# emulation; otherwise call tf.gather_nd directly with the op name derived
# from the output detail.
if optimizing_barracuda:
    output_tensor = barracuda_gather_nd(input_tensor1, input_tensor2)
else:
    output_tensor = tf.gather_nd(input_tensor1, input_tensor2, name=get_op_name(output_detail['name']))