I am trying to create a transform that shuffles the patches of each image in a batch. I aim to use it in the same manner as the rest of the transformations in torchvision:
trans = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ShufflePatches(patch_size=(16,16))  # our new transform
])
More specifically, the input is a BxCxHxW tensor. I want to split each image in the batch into non-overlapping patches of size patch_size, shuffle them, and regroup them into a single image.
Given an image of size 224x224, using ShufflePatches(patch_size=(112,112)) I would like to produce an output image in which the four 112x112 patches appear in shuffled order.
I think the solution has to do with torch.unfold and torch.fold, but I didn't manage to get any further.
Any help would be appreciated!
Indeed, unfold and fold seem appropriate in this case.
import torch
import torch.nn.functional as nnf

class ShufflePatches(object):
    def __init__(self, patch_size):
        self.ps = patch_size

    def __call__(self, x):
        # divide the batch of images into non-overlapping patches
        u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
        # permute the patches of each image in the batch
        pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None,...] for b_ in u], dim=0)
        # fold the permuted patches back together
        f = nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0)
        return f
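To see why this works, it may help to look at the shapes involved. The following is a minimal sketch, with tensor sizes chosen purely for illustration (they are assumptions, not taken from the question). unfold turns a BxCxHxW batch into a tensor of shape B x (C*ph*pw) x L, where each of the L columns is one flattened patch. When the stride equals the kernel size the patches do not overlap, so fold puts every value back exactly once and the round trip reproduces the input.

```python
import torch
import torch.nn.functional as nnf

# Illustrative sizes (assumed): a batch of 2 three-channel 224x224 images,
# split into 112x112 patches
x = torch.arange(2 * 3 * 224 * 224, dtype=torch.float32).reshape(2, 3, 224, 224)
ps = (112, 112)

# Each column of u is one flattened patch: shape (B, C * ph * pw, L)
u = nnf.unfold(x, kernel_size=ps, stride=ps, padding=0)
print(u.shape)  # torch.Size([2, 37632, 4])

# With non-overlapping patches (stride == kernel_size), fold inverts unfold
x_rec = nnf.fold(u, x.shape[-2:], kernel_size=ps, stride=ps, padding=0)
print(torch.equal(x_rec, x))  # True
```

Shuffling the columns of u along the last dimension before calling fold is therefore the same as shuffling the patches of the image.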
Here's an example with patch size=16:
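As a runnable stand-in for a visual example, here is a small sketch that applies the transform to a random batch and checks the result. The random input, the seed, and the batch size are assumptions made for illustration; the class is repeated (using torch.stack, which should behave the same as the cat-based version) so that the snippet is self-contained.

```python
import torch
import torch.nn.functional as nnf


class ShufflePatches:
    def __init__(self, patch_size):
        self.ps = patch_size

    def __call__(self, x):
        # (B, C, H, W) -> (B, C * ph * pw, L): one column per patch
        u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
        # draw an independent random permutation of the patches for each image
        pu = torch.stack([b_[:, torch.randperm(b_.shape[-1])] for b_ in u], dim=0)
        # reassemble the permuted patches into images of the original size
        return nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0)


torch.manual_seed(0)  # only to make this sketch reproducible
x = torch.rand(4, 3, 224, 224)  # random stand-in for a batch of images
y = ShufflePatches(patch_size=(16, 16))(x)

# The output has the same shape as the input
print(y.shape)  # torch.Size([4, 3, 224, 224])

# Each output image contains exactly the same pixel values, only rearranged
same_values = torch.equal(
    x.flatten(1).sort(dim=1).values,
    y.flatten(1).sort(dim=1).values,
)
print(same_values)  # True
```

Note that this sketch, like the class above, assumes a 4D BxCxHxW input whose height and width are multiples of the patch size. A single CxHxW image would need a leading batch dimension first, for example via x.unsqueeze(0).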