PyTorch Data Augmentation is taking too long

For a regression task, I need to train my models to generate density maps from RGB images. To augment my dataset, I decided to flip all the images horizontally. That means I also have to flip my ground-truth images, which I did.

dataset_for_augmentation.listDataset(train_list,
                        shuffle=True,
                        transform=transforms.Compose([
                            transforms.RandomHorizontalFlip(p=1),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                        ]),
                        target_transform=transforms.Compose([
                            transforms.RandomHorizontalFlip(p=1),
                            transforms.ToTensor()
                        ]),
                        train=True,
                        resize=4,
                        batch_size=args.batch_size,
                        num_workers=args.workers),

But here is the problem: for some reason, PyTorch's transforms.RandomHorizontalFlip only takes PIL images as input (NumPy arrays are not allowed). So I decided to convert my arrays to PIL Images.

img_path = self.lines[index]

img, target = load_data(img_path, self.train, resize=self.resize)

if type(target[0][0]) is np.float64:
    target = np.float32(target)

img = Image.fromarray(img)
target = Image.fromarray(target)

if self.transform is not None:
    img = self.transform(img)
    target = self.target_transform(target)

return img, target

And yes, this operation takes an enormous amount of time. Considering that I need to run it for thousands of images, 23 seconds per batch is not tolerable (it should be under half a second at most).

2019-11-01 16:29:02,497 - INFO - Epoch: [0][0/152]  Time 27.095 (27.095)    Data 23.150 (23.150)    Loss 93.7401 (93.7401)

I would appreciate any suggestions to speed up my augmentation process.

Bedir Yilmaz asked Mar 16 '26 10:03



1 Answer

You don't need to change the DataLoader to do that. You can use ToPILImage():

transform=transforms.Compose([
    transforms.ToPILImage(),  # check mode assumption in the documentation
    transforms.RandomHorizontalFlip(p=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

That said, I would avoid converting to PIL at all; it seems completely unnecessary. If you want to flip every image, why not do it with NumPy alone?

img_path = self.lines[index]

img, target = load_data(img_path, self.train, resize=self.resize)

if type(target[0][0]) is np.float64:
    target = np.float32(target)

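# assuming width is axis 1 -- see the note below
# np.flip returns a view with negative strides, which ToTensor
# (via torch.from_numpy) rejects, so make a contiguous copy
img = np.flip(img, axis=1).copy()
target = np.flip(target, axis=1).copy()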

if self.transform is not None:
    img = self.transform(img)
    target = self.target_transform(target)

return img, target

Then remove transforms.RandomHorizontalFlip(p=1) from both Compose pipelines. Since ToTensor also accepts a numpy.ndarray, no PIL conversion is needed.
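Here is a minimal NumPy-only sketch of the paired flip, assuming the image is H x W x C and the density map is H x W. The helper name flip_pair and the toy arrays are hypothetical, not part of the original code. np.flip returns a view with negative strides, which torch.from_numpy (used by ToTensor for ndarrays) does not accept, so the sketch copies each result into a contiguous array:

```python
import numpy as np


def flip_pair(img, target):
    """Mirror an image and its density map left-to-right (hypothetical helper).

    Assumes img is H x W x C and target is H x W, so width is axis 1 in both.
    """
    # np.flip returns a negatively strided view; ascontiguousarray makes a real copy.
    img_flipped = np.ascontiguousarray(np.flip(img, axis=1))
    target_flipped = np.ascontiguousarray(np.flip(target, axis=1))
    return img_flipped, target_flipped


# Toy data standing in for one RGB image and its density map.
img = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)
target = np.array([[0.0, 0.5, 1.5], [2.0, 0.0, 0.0]], dtype=np.float32)

img_f, target_f = flip_pair(img, target)
```

Because both arrays go through the same function, the image and its ground truth cannot get out of sync. Flipping also leaves the sum of the density map unchanged, so the count it encodes is preserved.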

Note: I am assuming that width is axis 1, since ToTensor expects an ndarray in H x W x C layout.

From the docs:

Converts a PIL Image or numpy.ndarray (H x W x C) ...
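To make the axis assumption concrete, here is a small NumPy sketch on a toy array (not the asker's data). In H x W x C layout a horizontal flip reverses axis 1; once the channels are moved first (C x H x W, the layout ToTensor produces), the same flip would reverse axis 2 instead:

```python
import numpy as np

# Toy H x W x C image: height 2, width 4, 3 channels.
hwc = np.arange(2 * 4 * 3).reshape(2, 4, 3)

# Horizontal flip in H x W x C layout: reverse the width axis (axis 1).
flipped_hwc = np.flip(hwc, axis=1)

# Channels-first layout (C x H x W), the layout ToTensor produces.
chw = hwc.transpose(2, 0, 1)

# The same horizontal flip in C x H x W layout reverses axis 2 instead.
flipped_chw = np.flip(chw, axis=2)

# Both routes describe the same mirrored image.
same = np.array_equal(flipped_hwc.transpose(2, 0, 1), flipped_chw)
```

If load_data returned arrays in a different layout, the axis passed to np.flip would need to change accordingly.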

Berriel answered Mar 19 '26 02:03
