I'm using TensorDataset to create a dataset from NumPy arrays:
# convert numpy arrays to pytorch tensors
X_train = torch.stack([torch.from_numpy(np.array(i)) for i in X_train])
y_train = torch.stack([torch.from_numpy(np.array(i)) for i in y_train])
# reshape into [N, C, H, W] and cast to float
X_train = X_train.reshape((-1, 1, 28, 28)).float()
# create dataset and dataloaders
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)
How do I apply data augmentation (transforms) to a TensorDataset?
For example, using ImageFolder I can pass transforms as one of its parameters: torchvision.datasets.ImageFolder(root, transform=...).
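For illustration, this is the kind of usage I mean (the directory path below is just a placeholder, with one subfolder per class):

import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
# 'path/to/images' is a placeholder root directory
folder_dataset = torchvision.datasets.ImageFolder('path/to/images', transform=transform)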
According to this reply by one of PyTorch's team members, it's not supported by default. Is there any alternative way to do so?
Feel free to ask if more code is needed to explain the problem.
Transforms are common image transformations available in the torchvision.transforms module. They can be chained together using Compose. Most transform classes have a function equivalent: functional transforms give fine-grained control over the transformations.
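As a quick illustration of the class-based versus functional styles mentioned above (the image path here is only a stand-in):

import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image

img = Image.open('digit.png')  # placeholder path to any image file

# class-style transforms, chained with Compose
pipeline = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
])
tensor_img = pipeline(img)

# functional style: the same flip, applied explicitly and deterministically
flipped = TF.hflip(img)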
The Dataset class is an abstract class used to define new types of (custom) datasets, whereas TensorDataset is a ready-to-use class that represents your data as a list of tensors.
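A plain TensorDataset simply zips tensors together sample by sample, with no transform hook; a minimal sketch:

import torch
from torch.utils.data import TensorDataset

features = torch.randn(100, 1, 28, 28)   # 100 fake single-channel images
labels = torch.randint(0, 10, (100,))    # 100 fake class labels
dataset = TensorDataset(features, labels)
x, y = dataset[0]                        # one (image, label) pair, no transform applied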
By default, transforms are not supported for TensorDataset, but we can create a custom Dataset class to add that option. Keep in mind, though, that most transforms are developed for PIL.Image. Anyway, here is a very simple MNIST example with very dummy transforms; a CSV file with MNIST is available here.
Code:
import numpy as np
import torch
from torch.utils.data import Dataset, TensorDataset
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Import MNIST dataset from csv file and convert it to torch tensors
with open('mnist_train.csv', 'r') as f:
    mnist_train = f.readlines()
# Images
X_train = np.array([[float(j) for j in i.strip().split(',')][1:] for i in mnist_train])
X_train = X_train.reshape((-1, 1, 28, 28))
X_train = torch.tensor(X_train)
# Labels
y_train = np.array([int(i[0]) for i in mnist_train])
y_train = y_train.reshape(y_train.shape[0], 1)
y_train = torch.tensor(y_train)
del mnist_train
class CustomTensorDataset(Dataset):
    """TensorDataset with support of transforms."""

    def __init__(self, tensors, transform=None):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        x = self.tensors[0][index]
        if self.transform:
            x = self.transform(x)
        y = self.tensors[1][index]
        return x, y

    def __len__(self):
        return self.tensors[0].size(0)
def imshow(img, title=''):
    """Plot the image batch."""
    plt.figure(figsize=(10, 10))
    plt.title(title)
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)), cmap='gray')
    plt.show()
# Dataset w/o any transformations
train_dataset_normal = CustomTensorDataset(tensors=(X_train, y_train), transform=None)
train_loader = torch.utils.data.DataLoader(train_dataset_normal, batch_size=16)

# iterate
for i, data in enumerate(train_loader):
    x, y = data
    imshow(torchvision.utils.make_grid(x, 4), title='Normal')
    break  # we need just one batch
# Let's add some transforms
# Dataset with flipping transformations

def vflip(tensor):
    """Flips tensor vertically."""
    return tensor.flip(1)

def hflip(tensor):
    """Flips tensor horizontally."""
    return tensor.flip(2)

train_dataset_vf = CustomTensorDataset(tensors=(X_train, y_train), transform=vflip)
train_loader = torch.utils.data.DataLoader(train_dataset_vf, batch_size=16)

for i, data in enumerate(train_loader):
    x, y = data
    imshow(torchvision.utils.make_grid(x, 4), title='Vertical flip')
    break

train_dataset_hf = CustomTensorDataset(tensors=(X_train, y_train), transform=hflip)
train_loader = torch.utils.data.DataLoader(train_dataset_hf, batch_size=16)

for i, data in enumerate(train_loader):
    x, y = data
    imshow(torchvision.utils.make_grid(x, 4), title='Horizontal flip')
    break
Output: three image grids showing the original batch, the vertically flipped batch, and the horizontally flipped batch.
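As a possible follow-up (not part of the original example): since most built-in torchvision transforms expect PIL.Image inputs, one workaround is to convert each tensor to a PIL image and back inside the transform passed to CustomTensorDataset. A rough sketch, assuming the tensors and class defined above and that X_train holds 0-255 pixel values so the uint8 cast is lossless:

import torch
import torchvision.transforms as transforms

pil_transform = transforms.Compose([
    transforms.ToPILImage(),          # tensor [1, 28, 28] -> PIL image
    transforms.RandomRotation(10),    # any PIL-based transform works here
    transforms.ToTensor(),            # back to a float tensor in [0, 1]
])
train_dataset_aug = CustomTensorDataset(
    tensors=(X_train.to(torch.uint8), y_train),
    transform=pil_transform,
)
train_loader_aug = torch.utils.data.DataLoader(train_dataset_aug, batch_size=16)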