The PyTorch tutorial for data loading and processing is quite specific to one example. Could someone help me with what the function should look like for more generic, simple loading of images?
Tutorial: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
My Data:
I have the MNIST dataset as PNG files in the following folder structure. (I know I can just use the dataset class, but this is purely to see how to load simple images into PyTorch without CSVs or complex features.)
The folder name is the label, and the images are 28x28 greyscale PNGs; no transformations are required.
data/
  train/
    0/
      3.png
      5.png
      13.png
      23.png
      ...
    1/
      3.png
      10.png
      11.png
      ...
    2/
      4.png
      13.png
      ...
    3/
      8.png
      ...
    4/
      ...
    5/
      ...
    6/
      ...
    7/
      ...
    8/
      ...
    9/
      ...
In PyTorch, torch.utils.data.Dataset stores the samples and their corresponding labels, and torch.utils.data.DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
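torchvision.datasets.ImageFolder is a generic data loader where the images are arranged in this way by default: root/dog/xxx.png, root/dog/xxy… Each subfolder of root is treated as one class, which matches the folder-per-label layout above.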
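ImageFolder does not parse the folder name as a number: it sorts the subfolder names and gives each one its index in that sorted list. The stdlib-only sketch below approximates that labeling step to show what happens; find_classes, make_samples and the extension list are assumptions modeled on torchvision's behaviour, not its actual code.

```python
import os

# Assumed, approximate list of accepted image extensions
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def find_classes(root):
    """Sort the subfolder names and map each one to its position."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


def make_samples(root):
    """Collect (path, label_index) pairs, walking class folders in sorted order."""
    classes, class_to_idx = find_classes(root)
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for dirpath, _, fnames in sorted(os.walk(class_dir)):
            for fname in sorted(fnames):
                # Non-image files (e.g. notes.txt) are skipped
                if fname.lower().endswith(IMG_EXTENSIONS):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return samples
```

For folders 0 to 9 the sorted order matches the digits, so the index equals the digit. With a folder named 10 it would not, since "10" sorts between "1" and "2".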
Here's what I did for PyTorch 0.4.1 (it should still work in 1.3):
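    import torch
    import torchvision

    def load_dataset():
        data_path = 'data/train/'
        # ImageFolder uses each subfolder (0-9) as a class; labels are
        # indices into the sorted list of folder names
        train_dataset = torchvision.datasets.ImageFolder(
            root=data_path,
            transform=torchvision.transforms.ToTensor()
        )
        # DataLoader batches and shuffles the samples
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=64,
            num_workers=0,
            shuffle=True
        )
        return train_loader

    for batch_idx, (data, target) in enumerate(load_dataset()):
        # train network
        pass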