Reconstructing an image after using extract_image_patches

I have an autoencoder that takes an image as an input and produces a new image as an output.

The input image (1x1024x1024x3) is split into patches (1024x32x32x3) before being fed to the network.

Once I have the output, also a batch of patches size 1024x32x32x3, I want to be able to reconstruct a 1024x1024x3 image. I thought I had this sussed by simply reshaping, but here's what happened.

First, here is the image as read by TensorFlow:

[Image: the original input image]

I patched the image with the following code

patch_size = [1, 32, 32, 1]
patches = tf.extract_image_patches([image],
    patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [1024, 32, 32, 3])

Here are a couple of patches from this image:

[Images: patches #168 and #169 from the input]

But it's when I reshape this patch data back into an image that things go pear-shaped.

reconstructed = tf.reshape(patches, [1, 1024, 1024, 3])
converted = tf.image.convert_image_dtype(reconstructed, tf.uint8)
encoded = tf.image.encode_png(converted)

[Image: the scrambled reconstructed output]

In this example, no processing has been done between patching and reconstructing. I have made a version of the code you can use to test this behaviour. To use it, run the following:

echo "/path/to/test-image.png" > inputs.txt
mkdir images
python3 image_test.py inputs.txt images

The code will save one input image and one output image per input image, plus one image for each of its 1024 patches, so comment out the lines that create the input and output images if you're only interested in saving the patches.

Somebody, please explain what happened :(

asked May 18 '17 by Chris Watts

7 Answers

Since I also struggled with this, I post a solution that might be useful to others. The trick is to realize that the inverse of tf.extract_image_patches is its gradient, as suggested here. Since the gradient of this op is implemented in Tensorflow, it is easy to build the reconstruction function:

import tensorflow as tf
from keras import backend as K
import numpy as np

def extract_patches(x):
    return tf.extract_image_patches(
        x,
        (1, 3, 3, 1),
        (1, 1, 1, 1),
        (1, 1, 1, 1),
        padding="VALID"
    )

def extract_patches_inverse(x, y):
    _x = tf.zeros_like(x)
    _y = extract_patches(_x)
    grad = tf.gradients(_y, _x)[0]
    # Divide by grad, to "average" together the overlapping patches
    # otherwise they would simply sum up
    return tf.gradients(_y, _x, grad_ys=y)[0] / grad

# Generate 10 fake images, last dimension can be different than 3
images = np.random.random((10, 28, 28, 3)).astype(np.float32)
# Extract patches
patches = extract_patches(images)
# Reconstruct image
# Notice that original images are only passed to infer the right shape
images_reconstructed = extract_patches_inverse(images, patches) 

# Compare with original (evaluating tf.Tensor into a numpy array)
# Here using Keras session
images_r = images_reconstructed.eval(session=K.get_session())

print (np.sum(np.square(images - images_r))) 
# 2.3820458e-11
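
For TF 2.x, where tf.gradients and the Keras session no longer exist, the same trick can be written with tf.GradientTape. A minimal sketch of that adaptation (assuming eager execution; this part is not from the original answer):

import tensorflow as tf
import numpy as np

def extract_patches(x):
    return tf.image.extract_patches(
        x, sizes=[1, 3, 3, 1], strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1], padding="VALID")

def extract_patches_inverse(shape, patches):
    _x = tf.zeros(shape)
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(_x)
        _y = extract_patches(_x)
    # An all-ones upstream gradient counts how many patches cover each pixel,
    # so dividing by it averages the overlapping patches instead of summing them
    overlap = tape.gradient(_y, _x, output_gradients=tf.ones_like(_y))
    return tape.gradient(_y, _x, output_gradients=patches) / overlap

images = np.random.random((10, 28, 28, 3)).astype(np.float32)
patches = extract_patches(images)
images_r = extract_patches_inverse(images.shape, patches)
print(np.sum(np.square(images - images_r)))  # ~0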
answered Sep 24 '22 by Marco Ancona


Use Update #2 below for your exact task. First, one small example (TF 1.0):

Consider an image of size (4, 4, 1), converted to patches of size (4, 2, 2, 1) and reconstructed back into the image.

import tensorflow as tf
image = tf.constant([[[1],   [2],  [3],  [4]],
                     [[5],   [6],  [7],  [8]],
                     [[9],  [10], [11], [12]],
                     [[13], [14], [15], [16]]])

patch_size = [1,2,2,1]
patches = tf.extract_image_patches([image],
    patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [4, 2, 2, 1])
reconstructed = tf.reshape(patches, [1, 4, 4, 1])
rec_new = tf.space_to_depth(reconstructed,2)
rec_new = tf.reshape(rec_new,[4,4,1])

sess = tf.Session()
I,P,R_n = sess.run([image,patches,rec_new])
print(I)
print(I.shape)
print(P.shape)
print(R_n)
print(R_n.shape)

Output:

[[[ 1][ 2][ 3][ 4]]
  [[ 5][ 6][ 7][ 8]]
  [[ 9][10][11][12]]
  [[13][14][15][16]]]
(4, 4, 1)
(4, 2, 2, 1)
[[[ 1][ 2][ 3][ 4]]
  [[ 5][ 6][ 7][ 8]]
  [[ 9][10][11][12]]
  [[13][14][15][16]]]
(4, 4, 1)

#Update - for 3 channels (still debugging; works only for p = sqrt(h))

import tensorflow as tf
import numpy as np
c = 3
h = 1024
p = 32

image = tf.random_normal([h,h,c])
patch_size = [1,p,p,1]
patches = tf.extract_image_patches([image],
    patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [h, p, p, c])
reconstructed = tf.reshape(patches, [1, h, h, c])
rec_new = tf.space_to_depth(reconstructed,p)
rec_new = tf.reshape(rec_new,[h,h,c])

sess = tf.Session()
I,P,R_n = sess.run([image,patches,rec_new])
print(I.shape)
print(P.shape)
print(R_n.shape)
err = np.sum((R_n-I)**2)
print(err)

Output :

(1024, 1024, 3)
(1024, 32, 32, 3)
(1024, 1024, 3)
0.0

#Update 2

Reconstructing directly from the output of extract_image_patches is difficult in the general case, because a plain reshape lays the patches out row by row in memory rather than putting them back in their spatial positions. Using other functions (space_to_batch_nd / batch_to_space_nd) to extract the patches makes the process much easier to reverse.

import tensorflow as tf
import numpy as np
c = 3
h = 1024
p = 128


image = tf.random_normal([1,h,h,c])

# Image to Patches Conversion
pad = [[0,0],[0,0]]
patches = tf.space_to_batch_nd(image,[p,p],pad)
patches = tf.split(patches,p*p,0)
patches = tf.stack(patches,3)
patches = tf.reshape(patches,[(h//p)**2,p,p,c])

# Do processing on patches
# Using patches here to reconstruct
# (note the integer divisions, so the shapes stay ints under Python 3)
patches_proc = tf.reshape(patches,[1,h//p,h//p,p*p,c])
patches_proc = tf.split(patches_proc,p*p,3)
patches_proc = tf.stack(patches_proc,axis=0)
patches_proc = tf.reshape(patches_proc,[p*p,h//p,h//p,c])

reconstructed = tf.batch_to_space_nd(patches_proc,[p, p],pad)

sess = tf.Session()
I,P,R_n = sess.run([image,patches,reconstructed])
print(I.shape)
print(P.shape)
print(R_n.shape)
err = np.sum((R_n-I)**2)
print(err)

Output:

(1, 1024, 1024, 3)
(64, 128, 128, 3)
(1, 1024, 1024, 3)
0.0

You can find other useful tensor transformation functions here: https://www.tensorflow.org/api_guides/python/array_ops

answered Sep 24 '22 by Harsha Pokkalla


tf.extract_image_patches is quite difficult to use, as it does a lot of stuff in the background.

If you only need non-overlapping patches, it's much easier to write them yourself. You can reconstruct the full image by inverting all the operations in image_to_patches; a sketch of that inverse follows the code sample below.

Code sample (plots original image and patches):

import tensorflow as tf
from skimage import io
import matplotlib.pyplot as plt


def image_to_patches(image, patch_height, patch_width):
    # resize the image so that its dimensions are divisible by patch_height and patch_width
    image_height = tf.cast(tf.shape(image)[0], dtype=tf.float32)
    image_width = tf.cast(tf.shape(image)[1], dtype=tf.float32)
    height = tf.cast(tf.ceil(image_height / patch_height) * patch_height, dtype=tf.int32)
    width = tf.cast(tf.ceil(image_width / patch_width) * patch_width, dtype=tf.int32)

    num_rows = height // patch_height
    num_cols = width // patch_width
    # make zero-padding
    image = tf.squeeze(tf.image.resize_image_with_crop_or_pad(image, height, width))

    # get slices along the 0-th axis
    image = tf.reshape(image, [num_rows, patch_height, width, -1])
    # h/patch_h, w, patch_h, c
    image = tf.transpose(image, [0, 2, 1, 3])
    # get slices along the 1-st axis
    # h/patch_h, w/patch_w, patch_w,patch_h, c
    image = tf.reshape(image, [num_rows, num_cols, patch_width, patch_height, -1])
    # num_patches, patch_w, patch_h, c
    image = tf.reshape(image, [num_rows * num_cols, patch_width, patch_height, -1])
    # num_patches, patch_h, patch_w, c
    return tf.transpose(image, [0, 2, 1, 3])


image = io.imread('http://www.petful.com/wp-content/uploads/2011/09/slow-blinking-cat.jpg')
print('Original image shape:', image.shape)
tile_size = 200
image = tf.constant(image)
tiles = image_to_patches(image, tile_size, tile_size)

sess = tf.Session()
I, tiles = sess.run([image, tiles])
print(I.shape)
print(tiles.shape)


plt.figure(figsize=(1 * (4 + 1), 5))
plt.subplot(5, 1, 1)
plt.imshow(I)
plt.title('original')
plt.axis('off')
for i, tile in enumerate(tiles):
    plt.subplot(5, 5, 5 + 1 + i)
    plt.imshow(tile)
    plt.title(str(i))
    plt.axis('off')
plt.show()
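
For completeness, here is a sketch of the inverse mentioned above, obtained by undoing each step of image_to_patches in reverse order (the function name and signature are mine, not part of the original answer):

def patches_to_image(patches, num_rows, num_cols):
    # inverse of image_to_patches, recovering the cropped/padded image
    _, patch_height, patch_width, _ = patches.get_shape().as_list()
    # num_patches, patch_w, patch_h, c
    image = tf.transpose(patches, [0, 2, 1, 3])
    # h/patch_h, w/patch_w, patch_w, patch_h, c
    image = tf.reshape(image, [num_rows, num_cols, patch_width, patch_height, -1])
    # h/patch_h, w, patch_h, c
    image = tf.reshape(image, [num_rows, num_cols * patch_width, patch_height, -1])
    # h/patch_h, patch_h, w, c
    image = tf.transpose(image, [0, 2, 1, 3])
    # h, w, c
    return tf.reshape(image, [num_rows * patch_height, num_cols * patch_width, -1])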
answered Sep 22 '22 by Temak


I don't know if the following code is an efficient implementation but it works!

import tensorflow as tf

# image: a [1, H, W, C] tensor, e.g.:
image = tf.random.normal([1, 1024, 1024, 3])
patch_size = 32

_, n_row, n_col, n_channel = image.shape
n_patch = n_row * n_col // (patch_size**2)  # assume square patch

patches = tf.image.extract_patches(image, sizes=[1, patch_size, patch_size, 1],
                                   strides=[1, patch_size, patch_size, 1],
                                   rates=[1, 1, 1, 1], padding='VALID')
patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])

# stitch the patches back together, one row of patches at a time
rows = tf.split(patches, n_col // patch_size, axis=0)
rows = [tf.concat(tf.unstack(x), axis=1) for x in rows]

reconstructed = tf.concat(rows, axis=0)
answered Sep 23 '22 by 51616


TF 2.0 users can use tf.nn.space_to_depth and tf.nn.depth_to_space if you aren't working with overlapping blocks.
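
To illustrate, a minimal round-trip sketch of this idea for non-overlapping 32x32 patches (my own example, not from the answer):

import tensorflow as tf

p = 32  # patch size
image = tf.random.normal([1, 1024, 1024, 3])

# fold every p x p block into the channel axis: (1, 32, 32, p*p*3)
blocks = tf.nn.space_to_depth(image, p)
# each spatial position now holds one flattened patch; reshape into a batch of patches
patches = tf.reshape(blocks, [-1, p, p, 3])

# inverse: restore the block layout, then unfold the channels
blocks_back = tf.reshape(patches, [1, 1024 // p, 1024 // p, p * p * 3])
reconstructed = tf.nn.depth_to_space(blocks_back, p)

print(tf.reduce_sum(tf.square(reconstructed - image)).numpy())  # 0.0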

answered Sep 22 '22 by Avedis


I may be a bit late, but since I got it working with TF-2.3, it might prove useful for others. The following code works for non-overlapping patches - single or multi-channel:

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers

class PatchesToImage(layers.Layer):
    def __init__(self, imgh, imgw, imgc, patsz, is_squeeze=True, **kwargs):
        super(PatchesToImage, self).__init__(**kwargs)
        self.H = (imgh // patsz) * patsz
        self.W = (imgw // patsz) * patsz
        self.C = imgc
        self.P = patsz
        self.is_squeeze = is_squeeze
        
    def call(self, inputs):
        # inputs: anything reshapeable to [batch, rows, cols, P*P, C],
        # e.g. a batch of non-overlapping patches [batch, num_patches, P, P, C]
        bs = tf.shape(inputs)[0]
        rows, cols = self.H // self.P, self.W // self.P
        patches = tf.reshape(inputs, [bs, rows, cols, -1, self.C])
        # one [batch, rows, cols, P*P] tensor per channel
        pats_by_clist = tf.unstack(patches, axis=-1)
        def tile_patches(ii):
            pats = pats_by_clist[ii]
            # depth_to_space unfolds each P*P depth vector back into a PxP block
            img = tf.nn.depth_to_space(pats, self.P)
            return img
        img = tf.map_fn(fn=tile_patches, elems=tf.range(self.C), fn_output_signature=inputs.dtype)
        # [C, batch, H, W, 1] -> [batch, H, W, C]
        img = tf.squeeze(img, axis=-1)
        img = tf.transpose(img, perm=[1,2,3,0])
        C = tf.shape(img)[-1]
        img = tf.cond(tf.logical_and(tf.constant(self.is_squeeze), C==1),
                      lambda: tf.squeeze(img, axis=-1), lambda: img)
        return img
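
A quick usage sketch (my own, assuming the input is a batch of non-overlapping patches shaped [batch, num_patches, P, P, C]):

# reassemble non-overlapping 4x4 patches of a 16x16 RGB batch, eagerly
p2i = PatchesToImage(imgh=16, imgw=16, imgc=3, patsz=4)
patches = tf.random.normal([2, (16 // 4)**2, 4, 4, 3])
img = p2i(patches)
print(img.shape)  # (2, 16, 16, 3)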
answered Sep 23 '22 by Solitary Angler


This code works for your specific case, and more generally whenever the image is square, the kernel is square, and the image size is divisible by the kernel size.

I did not test it for other cases.

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt


size = 1024
k_size = 32
axes_1_2_size = int(np.sqrt((size * size) / (k_size * k_size)))

# Define a placeholder for image (or load it directly if you prefer) 
img = tf.placeholder(tf.int32, shape=(1, size, size, 3))

# Extract patches
patches = tf.image.extract_image_patches(img, ksizes=[1, k_size, k_size, 1], 
                                         strides=[1, k_size, k_size, 1], 
                                         rates=[1, 1, 1, 1], padding='VALID')

# Reconstruct the image back from the patches
# First separate out the channel dimension
reconstruct = tf.reshape(patches, (1, axes_1_2_size, axes_1_2_size, k_size, k_size, 3)) 
# Transpose the axes (I got this axes tuple for transpose via experimentation)
reconstruct = tf.transpose(reconstruct, (0, 1, 3, 2, 4, 5))
# Reshape back
reconstruct = tf.reshape(reconstruct, (size, size, 3))

# Load your image as an array with shape (size, size, 3); random data as a stand-in
im_arr = np.random.randint(0, 256, size=(size, size, 3))

# Run the operations
with tf.Session() as sess:
    ps, r = sess.run([patches, reconstruct], feed_dict={img:[im_arr]})

# Plot the reconstructed image to verify (cast to uint8 for imshow)
plt.imshow(r.astype(np.uint8))
answered Sep 25 '22 by Saqib Shamsi