I am trying to filter a single-channel 2D image of size 256x256 by using unfold to create 16x16 blocks with an overlap of 8, as shown below:
# I = [256, 256] image
kernel_size = 16
stride = kernel_size // 2  # overlap of 8
patches = I.unfold(1, kernel_size, stride).unfold(0, kernel_size, stride)  # size = [31, 31, 16, 16]
I've started trying to put the image back together with fold, but I'm not quite there yet. I've tried using view to make the patches "fit" the shape they're supposed to have, but I don't see how this would preserve the original image. Perhaps I'm overthinking this.
# patches.shape = [31, 31, 16, 16]
patches = filt_data_block.contiguous().view(-1, kernel_size*kernel_size)  # [961, 256]
patches = patches.permute(1, 0)  # size = [256, 961]
Any help would be greatly appreciated. Thanks very much.
I believe you will benefit from using torch.nn.functional.fold and torch.nn.functional.unfold in this case, as these functions are built specifically for images (or any 4D tensors, i.e. with shape B x C x H x W).
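To connect this with the Tensor.unfold approach in the question, here is a minimal sketch, assuming a single-channel 256 x 256 tensor (the random image is a hypothetical stand-in), of how the [31, 31, 16, 16] patches could be rearranged into the [B, C*k*k, L] layout that F.fold expects. Note that unfolding dim 1 first and then dim 0, as in the question, leaves each 16 x 16 block transposed, so the sketch swaps the last two dimensions before flattening.

```python
import torch
import torch.nn.functional as F

kernel_size = 16
stride = kernel_size // 2

# Hypothetical single-channel 256 x 256 image
I = torch.rand(256, 256)

# Patches built as in the question: unfold dim 1, then dim 0 -> [31, 31, 16, 16]
patches = I.unfold(1, kernel_size, stride).unfold(0, kernel_size, stride)

# With this order, patches[i, j, a, b] == I[i*stride + b, j*stride + a],
# i.e. each block is transposed; swap the last two dims to get row-major blocks.
blocks = patches.permute(0, 1, 3, 2)

# ... a per-block filter could be applied to `blocks` here ...

# Flatten to [L, k*k], transpose to [k*k, L], add a batch dim -> [1, 256, 961]
cols = blocks.reshape(-1, kernel_size * kernel_size).t().unsqueeze(0)

# Sum the overlapping blocks back into a [1, 1, 256, 256] image
I_f = F.fold(cols, I.shape, kernel_size, stride=stride)

# Divide by the number of blocks that covered each pixel
norm_map = F.fold(torch.ones_like(cols), I.shape, kernel_size, stride=stride)
I_rec = (I_f / norm_map)[0, 0]

print(torch.allclose(I_rec, I))  # True
```

If the patches are instead built with I.unfold(0, ...).unfold(1, ...), the blocks already come out in row-major order and the permute can be dropped.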
Let's start with unfolding the image:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.datasets import load_sample_image  # used to load a sample image
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
# Load a flower image from sklearn.datasets and crop it to shape 1 x 3 x 256 x 256:
I = torch.from_numpy(load_sample_image('flower.jpg')).permute(2, 0, 1).unsqueeze(0).type(dtype)[..., 128:128+256, 256:256+256]
kernel_size = 16
stride = kernel_size//2
I_unf = F.unfold(I, kernel_size, stride=stride)
Here we obtain all the 16x16 image patches with a stride of 8 by using the F.unfold function. This results in a 3D tensor of shape torch.Size([1, 768, 961]), i.e. 961 patches, each holding 768 = 16 x 16 x 3 pixel values.
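The 961 and 768 follow from the usual sliding-window count: along each spatial axis there are (256 - 16) // 8 + 1 = 31 window positions, giving 31 x 31 = 961 patches, and each patch holds 3 x 16 x 16 = 768 values. A quick check, using a random tensor as a stand-in for the cropped flower image so that sklearn is not needed:

```python
import torch
import torch.nn.functional as F

kernel_size = 16
stride = kernel_size // 2
C, H, W = 3, 256, 256

# Number of window positions along each axis (no padding, dilation 1)
n_h = (H - kernel_size) // stride + 1
n_w = (W - kernel_size) // stride + 1

I = torch.rand(1, C, H, W)  # stand-in for the cropped image
I_unf = F.unfold(I, kernel_size, stride=stride)

print(n_h, n_w)     # 31 31
print(I_unf.shape)  # torch.Size([1, 768, 961])

# Column l of I_unf is the patch at block row l // n_w, block column l % n_w
l = 5 * n_w + 7
patch = I_unf[0, :, l].reshape(C, kernel_size, kernel_size)
ref = I[0, :, 5*stride:5*stride + kernel_size, 7*stride:7*stride + kernel_size]
print(torch.equal(patch, ref))  # True
```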
Now, say we wish to fold it back to I:
I_f = F.fold(I_unf,I.shape[-2:],kernel_size,stride=stride)
norm_map = F.fold(F.unfold(torch.ones(I.shape).type(dtype),kernel_size,stride=stride),I.shape[-2:],kernel_size,stride=stride)
I_f /= norm_map
We use F.fold, telling it the original shape of I, the kernel_size we used to unfold, and the stride. Folding I_unf sums the overlapping patches: every pixel receives the sum of its values from all patches that cover it, so the resulting image appears saturated. We therefore compute a normalization map that counts how many times each pixel was summed. An efficient way to do this is to take a tensor of ones and apply unfold followed by fold, mimicking the summation with overlaps. Dividing I_f by this normalization map recovers I.
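The fold-and-normalize round trip can be checked on a random single-channel tensor standing in for the question's 256 x 256 image (the random input is an assumption for illustration). With kernel_size = 16 and stride = 8, each pixel is covered by 1, 2 or 4 patches (corners, borders and interior), and dividing by that count recovers the input.

```python
import torch
import torch.nn.functional as F

kernel_size = 16
stride = kernel_size // 2

I = torch.rand(1, 1, 256, 256)  # B x C x H x W, single channel

I_unf = F.unfold(I, kernel_size, stride=stride)                 # [1, 256, 961]
# ... a per-patch filter could be applied to I_unf here ...
I_f = F.fold(I_unf, I.shape[-2:], kernel_size, stride=stride)  # overlapping sums

# Count how many patches cover each pixel by pushing ones through the same ops
ones = torch.ones_like(I)
norm_map = F.fold(F.unfold(ones, kernel_size, stride=stride),
                  I.shape[-2:], kernel_size, stride=stride)

print(torch.unique(norm_map))             # tensor([1., 2., 4.])
print(torch.allclose(I_f / norm_map, I))  # True
```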
Now, we can plot I_f and I to verify that the content is preserved:
fig, axes = plt.subplots(1, 2)
# Plot I:
axes[0].imshow(I[0, ...].permute(1, 2, 0).cpu() / 255)
# Plot I_f:
axes[1].imshow(I_f[0, ...].permute(1, 2, 0).cpu() / 255)
plt.show()
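This whole process also works for single-channel images. One thing to notice: if (H - kernel_size) or (W - kernel_size) is not a multiple of the stride, the last rows or columns are not covered by any patch, so norm_map will contain zeros at those edges (and dividing by it gives NaN there). You can handle this case as well, for example by padding the image so that every pixel is covered, or by dividing only where norm_map is non-zero.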
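One way to handle the uncovered edges is to pad the image on the bottom and right so that (H - kernel_size) and (W - kernel_size) become multiples of the stride, run the same unfold/fold/normalize steps, and crop back. The sketch below assumes a hypothetical 250 x 250 single-channel input and zero padding via F.pad; the helper name fold_back is also just for illustration.

```python
import torch
import torch.nn.functional as F

kernel_size = 16
stride = kernel_size // 2

I = torch.rand(1, 1, 250, 250)  # (250 - 16) is not a multiple of 8
H, W = I.shape[-2:]

def fold_back(x):
    """Unfold into patches, fold them back, and return (sums, overlap counts)."""
    x_unf = F.unfold(x, kernel_size, stride=stride)
    x_f = F.fold(x_unf, x.shape[-2:], kernel_size, stride=stride)
    norm_map = F.fold(F.unfold(torch.ones_like(x), kernel_size, stride=stride),
                      x.shape[-2:], kernel_size, stride=stride)
    return x_f, norm_map

# Without padding, the last rows/columns are never covered: norm_map has zeros there
_, norm_map_raw = fold_back(I)
print((norm_map_raw == 0).any())  # tensor(True)

# Pad bottom/right so every pixel is covered, then crop back to H x W
pad_h = (stride - (H - kernel_size) % stride) % stride
pad_w = (stride - (W - kernel_size) % stride) % stride
I_pad = F.pad(I, (0, pad_w, 0, pad_h))  # (left, right, top, bottom)
I_f, norm_map = fold_back(I_pad)
I_rec = (I_f / norm_map)[..., :H, :W]
print(torch.allclose(I_rec, I))  # True
```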