PyTorch transforms on TensorDataset

I'm using TensorDataset to create a dataset from NumPy arrays:

import numpy as np
import torch

# convert numpy arrays to pytorch tensors
X_train = torch.stack([torch.from_numpy(np.array(i)) for i in X_train])
y_train = torch.stack([torch.from_numpy(np.array(i)) for i in y_train])

# reshape into [C, H, W]
X_train = X_train.reshape((-1, 1, 28, 28)).float()

# create dataset and dataloaders
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)
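
For reference, here is a self-contained version of the snippet above. Random NumPy arrays stand in for the real X_train and y_train (an assumption, since the original data isn't shown), and X_train is assumed to already be a single 2-D array, so it can be converted with one torch.from_numpy call instead of stacking row by row. It also shows that TensorDataset just returns the stored tensor slices, with no place to hook a transform.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 100 flattened 28x28 "images" and integer labels.
rng = np.random.default_rng(0)
X_np = rng.random((100, 784), dtype=np.float32)
y_np = rng.integers(0, 10, size=100)

# Convert the whole array at once, then reshape into [N, C, H, W].
X_train = torch.from_numpy(X_np).reshape(-1, 1, 28, 28).float()
y_train = torch.from_numpy(y_np)

train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=64)

# Indexing returns (X_train[i], y_train[i]) unchanged: no transform is applied.
x0, y0 = train_dataset[0]
xb, yb = next(iter(train_loader))
print(x0.shape, xb.shape, yb.shape)
```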

How do I apply data augmentation (transforms) to TensorDataset?

For example, with ImageFolder I can pass the transforms as one of its parameters: torchvision.datasets.ImageFolder(root, transform=...).
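
Since TensorDataset is an ordinary Dataset subclass, one way to get an ImageFolder-style transform= argument is to subclass it and post-process what its __getitem__ returns. This is a minimal sketch, not an official API: the class name TransformTensorDataset is hypothetical, target_transform borrows the naming used by torchvision datasets, and the code assumes exactly two tensors (inputs, labels).

```python
import torch
from torch.utils.data import TensorDataset


class TransformTensorDataset(TensorDataset):
    """Hypothetical TensorDataset subclass with ImageFolder-style
    transform / target_transform keyword arguments (two tensors assumed)."""

    def __init__(self, *tensors, transform=None, target_transform=None):
        super().__init__(*tensors)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        # TensorDataset.__getitem__ returns a tuple of tensor slices.
        x, y = super().__getitem__(index)
        if self.transform is not None:
            x = self.transform(x)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y


# Two tiny 1x2x2 "images" and their labels.
X = torch.arange(8, dtype=torch.float32).reshape(2, 1, 2, 2)
y = torch.tensor([3, 7])
ds = TransformTensorDataset(X, y, transform=lambda t: t.flip(2))  # horizontal flip
print(ds[0])
```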

According to this reply by one of PyTorch's team members, it's not supported by default. Is there any alternative way to do so?
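
Another possible alternative (not taken from the linked reply): keep the plain TensorDataset and augment each batch after the DataLoader yields it. A minimal sketch with random stand-in data and a hypothetical helper random_hflip_batch; since it works on whole [N, C, H, W] batches, every transform has to be written with batched tensor operations.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def random_hflip_batch(x, p=0.5):
    """Flip each image in an [N, C, H, W] batch horizontally with probability p."""
    # One random draw per sample, reshaped to broadcast over C, H, W.
    mask = torch.rand(x.size(0), device=x.device) < p
    mask = mask.view(-1, 1, 1, 1)
    return torch.where(mask, x.flip(3), x)


X = torch.rand(10, 1, 28, 28)
y = torch.randint(0, 10, (10,))
loader = DataLoader(TensorDataset(X, y), batch_size=4, shuffle=True)

for xb, yb in loader:
    xb = random_hflip_batch(xb)  # augment on the fly; the dataset is unchanged
    # ... forward pass, loss and optimizer step would go here ...
```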

Feel free to ask if more code is needed to explain the problem.

asked Apr 09 '19 08:04 by kHarshit



1 Answer

By default, transforms are not supported for TensorDataset, but we can create a custom class to add that option. Keep in mind that, as I already mentioned, most transforms are designed for PIL.Image. Anyway, here is a very simple MNIST example with very basic (dummy) transforms. The CSV file with MNIST is here.

Code:

import numpy as np
import torch
from torch.utils.data import Dataset, TensorDataset

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

# Read the MNIST dataset from the CSV file and convert it to torch tensors

with open('mnist_train.csv', 'r') as f:
    mnist_train = f.readlines()

# Images
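X_train = np.array([[float(j) for j in i.strip().split(',')][1:] for i in mnist_train])
X_train = X_train.reshape((-1, 1, 28, 28))
# Scale pixel values from [0, 255] to [0, 1] so plt.imshow displays them without clipping
X_train = torch.tensor(X_train, dtype=torch.float32) / 255.0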

# Labels
# The label is the first comma-separated field of each row
y_train = np.array([int(i.split(',')[0]) for i in mnist_train])
y_train = y_train.reshape(y_train.shape[0], 1)
y_train = torch.tensor(y_train)

del mnist_train


class CustomTensorDataset(Dataset):
    """TensorDataset with support of transforms.
    """
    def __init__(self, tensors, transform=None):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        x = self.tensors[0][index]

        if self.transform:
            x = self.transform(x)

        y = self.tensors[1][index]

        return x, y

    def __len__(self):
        return self.tensors[0].size(0)


def imshow(img, title=''):
    """Plot the image batch.
    """
    plt.figure(figsize=(10, 10))
    plt.title(title)
    plt.imshow(np.transpose( img.numpy(), (1, 2, 0)), cmap='gray')
    plt.show()


# Dataset without any transformations
train_dataset_normal = CustomTensorDataset(tensors=(X_train, y_train), transform=None)
train_loader = torch.utils.data.DataLoader(train_dataset_normal, batch_size=16)

# iterate
for i, data in enumerate(train_loader):
    x, y = data  
    imshow(torchvision.utils.make_grid(x, 4), title='Normal')
    break  # we need just one batch


# Let's add some transforms

# Dataset with flipping transformations

def vflip(tensor):
    """Flips tensor vertically.
    """
    tensor = tensor.flip(1)
    return tensor


def hflip(tensor):
    """Flips tensor horizontally.
    """
    tensor = tensor.flip(2)
    return tensor


train_dataset_vf = CustomTensorDataset(tensors=(X_train, y_train), transform=vflip)
train_loader = torch.utils.data.DataLoader(train_dataset_vf, batch_size=16)

for i, data in enumerate(train_loader):
    x, y = data  
    imshow(torchvision.utils.make_grid(x, 4), title='Vertical flip')
    break


train_dataset_hf = CustomTensorDataset(tensors=(X_train, y_train), transform=hflip)
train_loader = torch.utils.data.DataLoader(train_dataset_hf, batch_size=16)

for i, data in enumerate(train_loader):
    x, y = data  
    imshow(torchvision.utils.make_grid(x, 4), title='Horizontal flip')
    break
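
The two flip functions above are plain callables, so they can be chained before being passed as transform. torchvision.transforms.Compose does exactly this (it calls each transform in order); the sketch below uses a small hypothetical compose helper instead so that it needs only torch, and it repeats vflip and hflip so it runs on its own.

```python
import torch


def vflip(tensor):
    """Flips a [C, H, W] tensor vertically (along H)."""
    return tensor.flip(1)


def hflip(tensor):
    """Flips a [C, H, W] tensor horizontally (along W)."""
    return tensor.flip(2)


def compose(*fns):
    """Hypothetical helper: apply fns left to right, like transforms.Compose."""
    def composed(x):
        for fn in fns:
            x = fn(x)
        return x
    return composed


img = torch.tensor([[[1., 2.], [3., 4.]]])  # shape [1, 2, 2]
both = compose(vflip, hflip)  # vertical then horizontal flip = 180-degree rotation
print(both(img))
```

The composed callable can then be used directly, e.g. CustomTensorDataset(tensors=(X_train, y_train), transform=both).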

Output:

[Image grids of the first batch: Normal, Vertical flip, Horizontal flip]
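
The flips above are deterministic: every sample is flipped every time, so the model never sees the original orientation. For augmentation one usually flips at random, so each epoch sees a different mix. A minimal sketch under that assumption: RandomFlip is a hypothetical class (torchvision provides RandomHorizontalFlip and RandomVerticalFlip for the same purpose), the data is random stand-in data, and CustomTensorDataset is repeated from the answer so the block runs on its own.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CustomTensorDataset(Dataset):
    """TensorDataset with support of transforms (as in the answer)."""

    def __init__(self, tensors, transform=None):
        assert all(tensors[0].size(0) == t.size(0) for t in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        x = self.tensors[0][index]
        if self.transform:
            x = self.transform(x)
        return x, self.tensors[1][index]

    def __len__(self):
        return self.tensors[0].size(0)


class RandomFlip:
    """Hypothetical transform: flip a [C, H, W] tensor along dim with probability p."""

    def __init__(self, dim, p=0.5):
        self.dim = dim
        self.p = p

    def __call__(self, tensor):
        # A fresh random draw on every __getitem__ call.
        if torch.rand(1).item() < self.p:
            return tensor.flip(self.dim)
        return tensor


X = torch.rand(8, 1, 28, 28)
y = torch.randint(0, 10, (8, 1))
dataset = CustomTensorDataset(tensors=(X, y), transform=RandomFlip(dim=2, p=0.5))
loader = DataLoader(dataset, batch_size=4, shuffle=True)
xb, yb = next(iter(loader))
```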

answered Sep 21 '22 21:09 by trsvchn