
Pytorch Unfold and Fold: How do I put this image tensor back together again?

I am trying to filter a single-channel 2D image of size 256x256, using unfold to create 16x16 blocks with an overlap of 8. This is shown below:

# I = [256, 256] single-channel image
kernel_size = 16
stride = kernel_size // 2  # overlap of 8
patches = I.unfold(1, kernel_size, stride).unfold(0, kernel_size, stride)  # size = [31, 31, 16, 16]

 

I have started to attempt to put the image back together with fold, but I’m not quite there yet. I’ve tried to use view to get the image to ‘fit’ the way it’s supposed to, but I don’t see how this would preserve the original image. Perhaps I’m overthinking this.

# patches.shape = [31, 31, 16, 16]
patches = patches.contiguous().view(-1, kernel_size*kernel_size)  # [961, 256]
patches = patches.permute(1, 0)  # size = [256, 961]

Any help would be greatly appreciated. Thanks very much.

asked Sep 24 '20 by Bled Clement

1 Answer

I believe you will benefit from using torch.nn.functional.fold and torch.nn.functional.unfold in this case, as these functions are built specifically for images (or any 4D tensors, i.e. tensors of shape B x C x H x W).

Let's start with unfolding the image:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.datasets import load_sample_image  # used to load a sample image

dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
# Load a flower image from sklearn.datasets and crop it to shape 1 x 3 x 256 x 256:
I = torch.from_numpy(load_sample_image('flower.jpg')).permute(2, 0, 1).unsqueeze(0).type(dtype)[..., 128:128+256, 256:256+256]
kernel_size = 16
stride = kernel_size // 2  # stride of 8 gives an overlap of 8
I_unf = F.unfold(I, kernel_size, stride=stride)

Here we obtain all the 16x16 image patches with a stride of 8 by using the F.unfold function. This results in a 3D tensor of shape torch.Size([1, 768, 961]), i.e. 961 patches, each containing 768 = 16 x 16 x 3 pixel values.
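
As a quick sanity check (my addition, not part of the original answer), the 961 and 768 follow directly from the unfold arithmetic:

# Sketch: derive the unfold output shape from first principles.
n = (256 - kernel_size) // stride + 1   # patch positions per spatial dim: (256 - 16) // 8 + 1 = 31
print(n * n)                            # 961 patches in total
print(3 * kernel_size * kernel_size)    # 768 values per patch (C * k * k)
print(I_unf.shape)                      # torch.Size([1, 768, 961])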

Now, say we wish to fold it back to I:

# Fold back to the original spatial shape; overlapping patches are summed:
I_f = F.fold(I_unf, I.shape[-2:], kernel_size, stride=stride)
# Count how many patches cover each pixel by unfolding and folding a tensor of ones:
norm_map = F.fold(F.unfold(torch.ones(I.shape).type(dtype), kernel_size, stride=stride), I.shape[-2:], kernel_size, stride=stride)
I_f /= norm_map

We use F.fold, telling it the original spatial shape of I (I.shape[-2:]), the kernel_size we used to unfold, and the stride. Folding I_unf sums the overlapping patches, so the resulting image will appear saturated. We therefore need a normalization map that divides out how many times each pixel was summed due to overlaps. An efficient way to compute it is to take a tensor of ones and apply unfold followed by fold, mimicking the same overlapping summation. Dividing I_f by this map recovers I.
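
To see why the normalization works, here is a minimal sketch (my own toy example, not from the original answer) on a tiny 4x4 tensor: fold(unfold(x)) sums overlapping pixels, and dividing by the ones-based map recovers x exactly:

x = torch.arange(16.).reshape(1, 1, 4, 4)          # tiny 1 x 1 x 4 x 4 image
x_unf = F.unfold(x, 2, stride=1)                   # nine 2x2 patches -> [1, 4, 9]
x_sum = F.fold(x_unf, (4, 4), 2, stride=1)         # overlapping pixels are summed
norm = F.fold(F.unfold(torch.ones_like(x), 2, stride=1), (4, 4), 2, stride=1)
print(torch.allclose(x_sum / norm, x))             # True: division recovers x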

Now we plot I_f and I to confirm that the content is preserved:

# Plot I:
plt.imshow(I[0, ...].permute(1, 2, 0).cpu() / 255)

[Image: the original flower crop I]

# Plot I_f:
plt.imshow(I_f[0, ...].permute(1, 2, 0).cpu() / 255)

[Image: the reconstructed image I_f, visually identical to I]

This whole process also works for single-channel images. One thing to note: if the spatial dimensions of the image are not divisible by the stride, norm_map will contain zeros at the edges, because some pixels are never covered by a patch, but you can easily handle this case as well.
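
If you do hit that case, one minimal way to guard against the resulting division by zero (my own sketch, not part of the original answer) is:

# Hypothetical guard for the non-divisible case: pixels never covered by a
# patch have norm_map == 0, so replace those zeros with ones before dividing.
safe_norm = torch.where(norm_map == 0, torch.ones_like(norm_map), norm_map)
I_f = F.fold(I_unf, I.shape[-2:], kernel_size, stride=stride) / safe_norm

Alternatively, padding the image to a stride-divisible size before unfolding avoids the problem entirely.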

answered Oct 06 '22 by Gil Pinsky