I have an autoencoder that takes an image as an input and produces a new image as an output.
The input image (1x1024x1024x3) is split into patches (1024x32x32x3) before being fed to the network.
Once I have the output, also a batch of patches size 1024x32x32x3, I want to be able to reconstruct a 1024x1024x3 image. I thought I had this sussed by simply reshaping, but here's what happened.
First, the image as read by TensorFlow:
I patched the image with the following code:
patch_size = [1, 32, 32, 1]
patches = tf.extract_image_patches([image],
                                   patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [1024, 32, 32, 3])
Here are a couple of patches from this image:
But it's when I reshape this patch data back into an image that things go pear-shaped.
reconstructed = tf.reshape(patches, [1, 1024, 1024, 3])
converted = tf.image.convert_image_dtype(reconstructed, tf.uint8)
encoded = tf.image.encode_png(converted)
In this example, no processing has been done between patching and reconstructing. I have made a version of the code you can use to test this behaviour. To use it, run the following:
echo "/path/to/test-image.png" > inputs.txt
mkdir images
python3 image_test.py inputs.txt images
For each input image, the code will save the input image, one image for each of the 1024 patches, and the reconstructed output image, so comment out the lines that create the input and output images if you only want to save the patches.
Somebody, please explain what happened :(
Since I also struggled with this, I'm posting a solution that might be useful to others. The trick is to realize that the inverse of tf.extract_image_patches is its gradient, as suggested here. Since the gradient of this op is implemented in TensorFlow, it is easy to build the reconstruction function:
import tensorflow as tf
from keras import backend as K
import numpy as np
def extract_patches(x):
    return tf.extract_image_patches(
        x,
        (1, 3, 3, 1),
        (1, 1, 1, 1),
        (1, 1, 1, 1),
        padding="VALID"
    )

def extract_patches_inverse(x, y):
    _x = tf.zeros_like(x)
    _y = extract_patches(_x)
    grad = tf.gradients(_y, _x)[0]
    # Divide by grad to "average" together the overlapping patches;
    # otherwise they would simply sum up
    return tf.gradients(_y, _x, grad_ys=y)[0] / grad
# Generate 10 fake images, last dimension can be different than 3
images = np.random.random((10, 28, 28, 3)).astype(np.float32)
# Extract patches
patches = extract_patches(images)
# Reconstruct image
# Notice that original images are only passed to infer the right shape
images_reconstructed = extract_patches_inverse(images, patches)
# Compare with original (evaluating tf.Tensor into a numpy array)
# Here using Keras session
images_r = images_reconstructed.eval(session=K.get_session())
print(np.sum(np.square(images - images_r)))
# 2.3820458e-11
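For newer TensorFlow versions, here is a minimal TF 2.x sketch of the same gradient trick, assuming eager execution; the helper names and the choice to pass the image shape instead of a reference tensor are mine:
import tensorflow as tf

def extract_patches(x):
    return tf.image.extract_patches(
        x, sizes=(1, 3, 3, 1), strides=(1, 1, 1, 1),
        rates=(1, 1, 1, 1), padding="VALID")

def extract_patches_inverse(shape, patches):
    _x = tf.zeros(shape)
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(_x)
        _y = extract_patches(_x)
    # The gradient with default (all-ones) upstream values counts how many
    # patches cover each pixel; dividing by it averages overlapping patches.
    overlap = tape.gradient(_y, _x)
    return tape.gradient(_y, _x, output_gradients=patches) / overlap

images = tf.random.uniform((10, 28, 28, 3))
patches = extract_patches(images)
images_r = extract_patches_inverse(tf.shape(images), patches)
print(tf.reduce_sum(tf.square(images - images_r)).numpy())  # ~0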
Use Update #2 below. Here is one small example for your task (TF 1.0):
Consider an image of size (4,4,1), converted to patches of size (4,2,2,1) and reconstructed back into an image.
import tensorflow as tf
image = tf.constant([[[1], [2], [3], [4]],
[[5], [6], [7], [8]],
[[9], [10], [11], [12]],
[[13], [14], [15], [16]]])
patch_size = [1,2,2,1]
patches = tf.extract_image_patches([image],
                                   patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [4, 2, 2, 1])
reconstructed = tf.reshape(patches, [1, 4, 4, 1])
rec_new = tf.space_to_depth(reconstructed,2)
rec_new = tf.reshape(rec_new,[4,4,1])
sess = tf.Session()
I,P,R_n = sess.run([image,patches,rec_new])
print(I)
print(I.shape)
print(P.shape)
print(R_n)
print(R_n.shape)
Output:
[[[ 1][ 2][ 3][ 4]]
[[ 5][ 6][ 7][ 8]]
[[ 9][10][11][12]]
[[13][14][15][16]]]
(4, 4, 1)
(4, 2, 2, 1)
[[[ 1][ 2][ 3][ 4]]
[[ 5][ 6][ 7][ 8]]
[[ 9][10][11][12]]
[[13][14][15][16]]]
(4, 4, 1)
# Update - for 3 channels (still debugging): this works only for p = sqrt(h)
import tensorflow as tf
import numpy as np
c = 3
h = 1024
p = 32
image = tf.random_normal([h,h,c])
patch_size = [1,p,p,1]
patches = tf.extract_image_patches([image],
                                   patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [h, p, p, c])
reconstructed = tf.reshape(patches, [1, h, h, c])
rec_new = tf.space_to_depth(reconstructed,p)
rec_new = tf.reshape(rec_new,[h,h,c])
sess = tf.Session()
I,P,R_n = sess.run([image,patches,rec_new])
print(I.shape)
print(P.shape)
print(R_n.shape)
err = np.sum((R_n-I)**2)
print(err)
Output :
(1024, 1024, 3)
(1024, 32, 32, 3)
(1024, 1024, 3)
0.0
# Update 2
Reconstructing from the output of extract_image_patches seems difficult, so I used other functions to extract patches and reversed the process to reconstruct the image, which seems easier.
import tensorflow as tf
import numpy as np
c = 3
h = 1024
p = 128
image = tf.random_normal([1,h,h,c])
# Image to Patches Conversion
pad = [[0,0],[0,0]]
patches = tf.space_to_batch_nd(image,[p,p],pad)
patches = tf.split(patches,p*p,0)
patches = tf.stack(patches,3)
patches = tf.reshape(patches,[(h//p)**2,p,p,c])
# Do processing on patches
# Using patches here to reconstruct
patches_proc = tf.reshape(patches,[1,h//p,h//p,p*p,c])
patches_proc = tf.split(patches_proc,p*p,3)
patches_proc = tf.stack(patches_proc,axis=0)
patches_proc = tf.reshape(patches_proc,[p*p,h//p,h//p,c])
reconstructed = tf.batch_to_space_nd(patches_proc,[p, p],pad)
sess = tf.Session()
I,P,R_n = sess.run([image,patches,reconstructed])
print(I.shape)
print(P.shape)
print(R_n.shape)
err = np.sum((R_n-I)**2)
print(err)
Output:
(1, 1024, 1024, 3)
(64, 128, 128, 3)
(1, 1024, 1024, 3)
0.0
You could see other cool tensor transformation functions here: https://www.tensorflow.org/api_guides/python/array_ops
tf.extract_image_patches is quite difficult to use, as it does a lot of stuff in the background. If you just need non-overlapping patches, it's much easier to write it yourself. You can reconstruct the full image by inverting all operations in image_to_patches; a sketch of that inverse follows the code sample below.
Code sample (plots original image and patches):
import tensorflow as tf
from skimage import io
import matplotlib.pyplot as plt
def image_to_patches(image, patch_height, patch_width):
    # resize image so that its dimensions are divisible by patch_height and patch_width
    image_height = tf.cast(tf.shape(image)[0], dtype=tf.float32)
    image_width = tf.cast(tf.shape(image)[1], dtype=tf.float32)
    height = tf.cast(tf.ceil(image_height / patch_height) * patch_height, dtype=tf.int32)
    width = tf.cast(tf.ceil(image_width / patch_width) * patch_width, dtype=tf.int32)
    num_rows = height // patch_height
    num_cols = width // patch_width
    # make zero-padding
    image = tf.squeeze(tf.image.resize_image_with_crop_or_pad(image, height, width))
    # get slices along the 0-th axis
    image = tf.reshape(image, [num_rows, patch_height, width, -1])
    # h/patch_h, w, patch_h, c
    image = tf.transpose(image, [0, 2, 1, 3])
    # get slices along the 1-st axis
    # h/patch_h, w/patch_w, patch_w, patch_h, c
    image = tf.reshape(image, [num_rows, num_cols, patch_width, patch_height, -1])
    # num_patches, patch_w, patch_h, c
    image = tf.reshape(image, [num_rows * num_cols, patch_width, patch_height, -1])
    # num_patches, patch_h, patch_w, c
    return tf.transpose(image, [0, 2, 1, 3])
image = io.imread('http://www.petful.com/wp-content/uploads/2011/09/slow-blinking-cat.jpg')
print('Original image shape:', image.shape)
tile_size = 200
image = tf.constant(image)
tiles = image_to_patches(image, tile_size, tile_size)
sess = tf.Session()
I, tiles = sess.run([image, tiles])
print(I.shape)
print(tiles.shape)
plt.figure(figsize=(1 * (4 + 1), 5))
plt.subplot(5, 1, 1)
plt.imshow(I)
plt.title('original')
plt.axis('off')
for i, tile in enumerate(tiles):
    plt.subplot(5, 5, 5 + 1 + i)
    plt.imshow(tile)
    plt.title(str(i))
    plt.axis('off')
plt.show()
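As promised above, here is a hedged sketch of the inverse; patches_to_image is a hypothetical helper (not in the original code) that assumes you still know the patch grid size num_rows x num_cols:
def patches_to_image(patches, num_rows, num_cols):
    # num_patches, patch_h, patch_w, c -> num_patches, patch_w, patch_h, c
    patch_height = tf.shape(patches)[1]
    patch_width = tf.shape(patches)[2]
    image = tf.transpose(patches, [0, 2, 1, 3])
    # num_rows, num_cols, patch_w, patch_h, c
    image = tf.reshape(image, [num_rows, num_cols, patch_width, patch_height, -1])
    # num_rows, w, patch_h, c (undoes the matching reshape in image_to_patches)
    image = tf.reshape(image, [num_rows, num_cols * patch_width, patch_height, -1])
    # num_rows, patch_h, w, c
    image = tf.transpose(image, [0, 2, 1, 3])
    # h, w, c
    return tf.reshape(image, [num_rows * patch_height, num_cols * patch_width, -1])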
I don't know if the following code is an efficient implementation, but it works!
# image: a batch of one image, shape (1, n_row, n_col, n_channel)
_, n_row, n_col, n_channel = image.shape
n_patch = n_row * n_col // (patch_size ** 2)  # assume square patches
patches = tf.image.extract_patches(image, sizes=[1, patch_size, patch_size, 1],
                                   strides=[1, patch_size, patch_size, 1],
                                   rates=[1, 1, 1, 1], padding='VALID')
patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])
rows = tf.split(patches, n_col // patch_size, axis=0)
rows = [tf.concat(tf.unstack(row), axis=1) for row in rows]
reconstructed = tf.concat(rows, axis=0)
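A quick self-contained check of the same idea; the helper name reassemble and the shapes are chosen here for illustration, and n_row // patch_size is used for the number of patch rows so the split also holds for non-square images:
import tensorflow as tf

def reassemble(image, patch_size):
    _, n_row, n_col, n_channel = image.shape
    n_patch = n_row * n_col // (patch_size ** 2)
    patches = tf.image.extract_patches(
        image, sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1], padding='VALID')
    patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])
    # one split per row of the patch grid
    rows = tf.split(patches, n_row // patch_size, axis=0)
    rows = [tf.concat(tf.unstack(row), axis=1) for row in rows]
    return tf.concat(rows, axis=0)

image = tf.random.uniform([1, 256, 256, 3])
print(tf.reduce_max(tf.abs(reassemble(image, 32) - image[0])).numpy())  # 0.0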
TF 2.0 users can use space_to_depth and depth_to_space if you aren't doing overlapping blocks.
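For example, a minimal round-trip sketch (block size and image shape assumed):
import tensorflow as tf

p = 32
image = tf.random.uniform([1, 1024, 1024, 3])            # NHWC
blocks = tf.nn.space_to_depth(image, p)                  # (1, 32, 32, p*p*3)
restored = tf.nn.depth_to_space(blocks, p)               # (1, 1024, 1024, 3)
print(tf.reduce_max(tf.abs(restored - image)).numpy())   # 0.0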
I may be a bit late, but since I got it working with TF 2.3, it might prove useful for others. The following code works for non-overlapping patches, single- or multi-channel:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers
class PatchesToImage(layers.Layer):
    def __init__(self, imgh, imgw, imgc, patsz, is_squeeze=True, **kwargs):
        super(PatchesToImage, self).__init__(**kwargs)
        self.H = (imgh // patsz) * patsz
        self.W = (imgw // patsz) * patsz
        self.C = imgc
        self.P = patsz
        self.is_squeeze = is_squeeze

    def call(self, inputs):
        bs = tf.shape(inputs)[0]
        rows, cols = self.H // self.P, self.W // self.P
        patches = tf.reshape(inputs, [bs, rows, cols, -1, self.C])
        pats_by_clist = tf.unstack(patches, axis=-1)

        def tile_patches(ii):
            pats = pats_by_clist[ii]
            img = tf.nn.depth_to_space(pats, self.P)
            return img

        img = tf.map_fn(fn=tile_patches, elems=tf.range(self.C),
                        fn_output_signature=inputs.dtype)
        img = tf.squeeze(img, axis=-1)
        img = tf.transpose(img, perm=[1, 2, 3, 0])
        C = tf.shape(img)[-1]
        img = tf.cond(tf.logical_and(tf.constant(self.is_squeeze), C == 1),
                      lambda: tf.squeeze(img, axis=-1), lambda: img)
        return img
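A hypothetical usage sketch; the input layout, a batch of row-major patches flattened to (batch, num_patches, P*P*C), is an assumption on my part:
layer = PatchesToImage(imgh=1024, imgw=1024, imgc=3, patsz=32)
patches = tf.random.uniform([2, 1024, 32 * 32 * 3])  # (batch, num_patches, P*P*C)
images = layer(patches)
print(images.shape)  # (2, 1024, 1024, 3)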
This code works for your specific case, as well as for cases when the images are square, with a square kernel and the image size is divisible by the kernel size.
I did not test it for other cases.
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
size = 1024
k_size = 32
axes_1_2_size = int(np.sqrt((size * size) / (k_size * k_size)))
# Define a placeholder for image (or load it directly if you prefer)
img = tf.placeholder(tf.int32, shape=(1, size, size, 3))
# Extract patches
patches = tf.image.extract_image_patches(img, ksizes=[1, k_size, k_size, 1],
                                         strides=[1, k_size, k_size, 1],
                                         rates=[1, 1, 1, 1], padding='VALID')
# Reconstruct the image back from the patches
# First separate out the channel dimension
reconstruct = tf.reshape(patches, (1, axes_1_2_size, axes_1_2_size, k_size, k_size, 3))
# Transpose the axes (I got this axes tuple for transpose via experimentation)
reconstruct = tf.transpose(reconstruct, (0, 1, 3, 2, 4, 5))
# Reshape back
reconstruct = tf.reshape(reconstruct, (size, size, 3))
im_arr = ...  # load an image with shape (size, size, 3) here
# Run the operations
with tf.Session() as sess:
    ps, r = sess.run([patches, reconstruct], feed_dict={img: [im_arr]})

# Plot the reconstructed image to verify
plt.imshow(r)