Open1

gather_nd (GatherND) の Unity Barracuda 対応のためのワークアラウンドメモ

PINTOPINTO
def _barracuda_drop_unit_batch(tensor, params, indices):
    """Return `tensor` as a rank-3 value, or abort if its shape is unsupported.

    A rank-4 tensor whose leading dimension is 1 has that dimension removed by
    indexing with ``[0]``; a rank-3 tensor is returned unchanged. Any other
    shape prints the error messages below and terminates the process.

    Args:
        tensor: The tensor being validated (either `params` or `indices`).
        params: Current `params` value, used only in the error message.
        indices: Current `indices` value, used only in the error message.

    Returns:
        The rank-3 view of `tensor`.
    """
    rank = len(tensor.shape)
    if rank == 4 and tensor.shape[0] == 1:
        return tensor[0]
    if rank == 3:
        return tensor
    print(f'{Color.RED}ERROR:{Color.RESET} gather_nd when optimizing_barracuda is enabled must have 4 dimensions and batch size = 1 or 3 dimensions.')
    print(f'{Color.RED}ERROR:{Color.RESET} params.shape: {params.shape}, indices.shape: {indices.shape}')
    sys.exit(-1)


def barracuda_gather_nd(params, indices):
    """Emulate ``tf.gather_nd`` using reshape, multiply, reduce_sum and gather.

    Workaround for targets where GatherND itself is to be avoided (the caller
    selects this path when ``optimizing_barracuda`` is set). Both inputs are
    first reduced to rank 3. Each index tuple stored along the last axis of
    `indices` is converted into a single flat offset, the leading axes of
    `params` are flattened to match, and the lookup is done with a plain
    ``tf.gather``. The result is reshaped back and given a leading axis of
    size 1.

    NOTE: the index tuples address the rank-3 `params`, i.e. after a leading
    batch axis of size 1 has been removed, so they must not contain a batch
    coordinate.

    Args:
        params: Tensor to gather from. Rank 3, or rank 4 with a leading
            dimension of 1.
        indices: Integer index tensor. Rank 3, or rank 4 with a leading
            dimension of 1. The length of its last axis is the number of
            leading `params` axes that are indexed.

    Returns:
        A tensor of shape ``[1] + indices.shape[:-1] + params.shape[k:]``,
        where the shapes are those of the rank-3 views and ``k`` is the length
        of the last axis of `indices`.

    Exits:
        Prints an error and calls ``sys.exit(-1)`` when either input is
        neither rank 3 nor rank 4 with a leading dimension of 1.
    """
    # Validate in the original order (indices first, then params) so that the
    # shapes reported in an error message are the same as before.
    indices = _barracuda_drop_unit_batch(indices, params, indices)
    params = _barracuda_drop_unit_batch(params, params, indices)

    idx_shape = indices.shape
    params_shape = params.shape
    # Number of leading params axes addressed by each index tuple.
    idx_dims = idx_shape[-1]
    # Trailing params axes that are copied whole for every index tuple.
    gather_shape = params_shape[idx_dims:]
    # Collapse the indexed axes into one so a 1-D gather can address them.
    params_flat = tf.reshape(params, tf.concat([[-1], gather_shape], axis=0))
    # Row-major stride of each indexed axis: [d1*d2*..., ..., d_last, 1].
    axis_step = tf.math.cumprod(params_shape[:idx_dims], exclusive=True, reverse=True)
    # The strides are derived from a static shape and need not share the dtype
    # of `indices` (e.g. int64 indices); align them so the multiply below does
    # not fail on mismatched operand types.
    axis_step = tf.cast(axis_step, indices.dtype)
    mul = tf.math.multiply(indices, axis_step)
    # Sum of index * stride over the last axis gives the flat offset.
    indices_flat = tf.reduce_sum(mul, axis=-1)
    result_flat = tf.gather(params_flat, indices_flat)
    # Restore the index layout plus the gathered slice shape, then add back a
    # leading axis of size 1.
    return tf.expand_dims(tf.reshape(result_flat, tf.concat([idx_shape[:-1], gather_shape], axis=0)), axis=0)

# Pick the gather implementation: the reshape/gather emulation when
# optimizing_barracuda is set, otherwise the native tf.gather_nd op.
if optimizing_barracuda:
    output_tensor = barracuda_gather_nd(input_tensor1, input_tensor2)
else:
    output_tensor = tf.gather_nd(
        input_tensor1,
        input_tensor2,
        name=get_op_name(output_detail['name']),
    )