Open3

tf.image.extract_patches の標準レイヤーへの置き換え試行

極めて限定的な状況にしか対応できないので全く使えない

import tensorflow as tf
import numpy as np

# Experiment configuration
input_shape = [5, 10, 10, 6]
patch_size = 3  # side length (height == width) of the square patch

# Example input: consecutive integers arranged as input_shape, stored as float32
x_original = tf.reshape(tf.range(np.prod(input_shape)), input_shape)
x_original = tf.cast(x_original, dtype=tf.float32)

x_shape = x_original.shape  # shape of the untouched input, kept for reference

# Bring the last input axis next to the batch axis, then append a singleton
# axis that matches the kernel's single input channel.
x = tf.expand_dims(tf.transpose(x_original, perm=[0, 3, 1, 2]), -1)

# Identity matrix reshaped to
# [filter_height, filter_width, in_channels, out_channels]
kernel = tf.eye(patch_size**2)
kernel = tf.reshape(kernel, [patch_size, patch_size, 1, patch_size**2])

# Convolve with the identity kernel
patches_simulation = tf.nn.conv2d(
    x,
    kernel,
    strides=[1, 1, 1, 1],
    padding='VALID',
)
# Send axis 1 to the end so that it sits right after the filter axis
patches_simulation = tf.transpose(patches_simulation, perm=[0, 2, 3, 4, 1])
patches_simulation_shape = patches_simulation.shape
# Keep the first three axes and merge the trailing two into one
merged_shape = [patches_simulation_shape[axis] for axis in range(3)] + [-1]
patches_simulation = tf.reshape(patches_simulation, merged_shape)

# Reference result the simulation is compared against
patches = tf.image.extract_patches(
    x_original,
    sizes=[1, patch_size, patch_size, 1],
    strides=[1, 1, 1, 1],
    rates=[1, 1, 1, 1],
    padding='VALID',
)

print(f'tf.image.extract_patches shape {patches.shape} simulation shape {patches_simulation.shape} same shape: {patches.shape == patches_simulation.shape}')
print(f'Simulation is correct? {tf.reduce_all(tf.math.equal(patches, patches_simulation)).numpy()}')
def numpy_eip(arr, ksizes, strides, rates, padding):
    """Extract sliding patches from a 4-D array and flatten each one.

    The argument layout mirrors TensorFlow's patch-extraction ops
    (NOTE(review): presumably meant as a NumPy stand-in for
    ``tf.image.extract_patches`` -- confirm against the caller).

    Args:
        arr: 4-D array-like. Patches slide over axes 1 and 2; axes 0 and 3
            are carried through unchanged.
        ksizes: Length-4 sequence; entries 1 and 2 give the number of
            samples per patch along axes 1 and 2.
        strides: Length-4 sequence; entries 1 and 2 give the step between
            consecutive patch origins.
        rates: Length-4 sequence; entries 1 and 2 give the spacing between
            samples inside a patch (dilation).
        padding: ``'VALID'`` (no padding) or ``'SAME'`` (zero padding, with
            the odd extra element placed at the end of the axis).

    Returns:
        ``numpy.ndarray`` of shape
        ``[arr.shape[0], out_rows, out_cols,
        ksizes[1] * ksizes[2] * arr.shape[3]]``. Each patch is flattened in
        (patch row, patch column, last-axis) order. ``out_rows`` and/or
        ``out_cols`` are 0 when no patch fits.

    Raises:
        ValueError: If ``padding`` is neither ``'SAME'`` nor ``'VALID'``,
            or if a stride or rate along axis 1 or 2 is smaller than 1.
    """
    arr = np.asarray(arr)
    k_rows, k_cols = ksizes[1], ksizes[2]
    stride_rows, stride_cols = strides[1], strides[2]
    rate_rows, rate_cols = rates[1], rates[2]

    # Reject steps that cannot describe a forward-moving window before any
    # arithmetic (a zero stride would otherwise divide by zero below).
    for label, pair in (('strides', (stride_rows, stride_cols)),
                        ('rates', (rate_rows, rate_cols))):
        if pair[0] < 1 or pair[1] < 1:
            raise ValueError(
                '%s must be >= 1 along axes 1 and 2, got %r' % (label, pair))

    # Extent covered by one dilated patch: (k - 1) * rate + 1.
    span_rows = (k_rows - 1) * rate_rows + 1
    span_cols = (k_cols - 1) * rate_cols + 1

    if padding == 'SAME':
        # Pad just enough that the last window (starting at the largest
        # multiple of the stride inside the input) fits completely.
        pad_rows = max(
            0,
            (arr.shape[1] - 1) // stride_rows * stride_rows
            + span_rows - arr.shape[1])
        pad_cols = max(
            0,
            (arr.shape[2] - 1) // stride_cols * stride_cols
            + span_cols - arr.shape[2])
        arr = np.pad(arr, [
            (0, 0),
            (pad_rows // 2, pad_rows - pad_rows // 2),
            (pad_cols // 2, pad_cols - pad_cols // 2),
            (0, 0),
        ])
    elif padding != 'VALID':
        raise ValueError('Padding type "%s" is not supported' % padding)

    # Origins of every window that fits entirely inside the (padded) input.
    # An empty range here simply yields an empty result further down.
    row_starts = np.array(
        range(0, arr.shape[1] - span_rows + 1, stride_rows), dtype=np.intp)
    col_starts = np.array(
        range(0, arr.shape[2] - span_cols + 1, stride_cols), dtype=np.intp)

    # Absolute sample positions per window: [n_windows, k].
    row_idx = row_starts[:, None] + rate_rows * np.arange(k_rows)
    col_idx = col_starts[:, None] + rate_cols * np.arange(k_cols)

    # One gather instead of a Python loop per window. The two index arrays
    # broadcast to [out_rows, out_cols, k_rows, k_cols], so the result is
    # [batch, out_rows, out_cols, k_rows, k_cols, channels].
    gathered = arr[:,
                   row_idx[:, None, :, None],
                   col_idx[None, :, None, :],
                   :]

    return gathered.reshape(
        arr.shape[0],
        len(row_starts),
        len(col_starts),
        k_rows * k_cols * arr.shape[3],
    )
ログインするとコメントできます