
How do you load MNIST images into a PyTorch DataLoader?

The PyTorch tutorial for data loading and processing is quite specific to one example. Could someone help me with what the function should look like for a more generic, simple loading of images?

Tutorial: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

My Data:

I have the MNIST dataset as images in the following folder structure. (I know I can just use the built-in dataset class, but this is purely to see how to load simple images into PyTorch without CSVs or complex features.)

The folder name is the label, and the images are 28x28 greyscale PNGs; no transformations are required.

data
    train
        0
            3.png
            5.png
            13.png
            23.png
            ...
        1
            3.png
            10.png
            11.png
            ...
        2
            4.png
            13.png
            ...
        3
            8.png
            ...
        4
            ...
        5
            ...
        6
            ...
        7
            ...
        8
            ...
        9
            ...
Terry asked Apr 26 '18 21:04


People also ask

What is the difference between a PyTorch dataset and a PyTorch DataLoader?

Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
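To make that split concrete, here is a minimal sketch that uses random tensors in place of real images. The class name `TinyDigits` and the data it holds are made up for illustration; only `Dataset` and `DataLoader` come from PyTorch.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class TinyDigits(Dataset):
    """Hypothetical dataset: 10 random 1x28x28 'images' labelled 0..9."""

    def __init__(self):
        self.images = torch.rand(10, 1, 28, 28)
        self.labels = torch.arange(10)

    def __len__(self):
        # The DataLoader uses this to know how many samples exist.
        return len(self.labels)

    def __getitem__(self, idx):
        # The Dataset answers "what is sample idx?" and nothing more.
        return self.images[idx], self.labels[idx]


dataset = TinyDigits()

# The DataLoader handles batching (and optionally shuffling and workers).
loader = DataLoader(dataset, batch_size=4, shuffle=False)

batch_shapes = [tuple(images.shape) for images, labels in loader]
```

With 10 samples and a batch size of 4, the loader yields two full batches and one partial batch of 2, because `drop_last` defaults to `False`.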

What is ImageFolder in PyTorch?

A generic data loader where the images are arranged in this way by default: root/dog/xxx.png root/dog/xxy.
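The folder-to-label mapping can be sketched in plain Python. This is only an approximation of what ImageFolder does, not torchvision's code: the helper `index_image_folder` is a hypothetical name, it assumes `.png` files only, and the temporary directory stands in for a real `data/train` folder.

```python
import os
import tempfile


def index_image_folder(root):
    """Return (samples, class_to_idx) for a root/<class>/<file> layout."""
    # Class names are the subfolder names, sorted as strings.
    classes = sorted(
        entry.name for entry in os.scandir(root) if entry.is_dir()
    )
    class_to_idx = {name: i for i, name in enumerate(classes)}

    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(".png"):
                path = os.path.join(class_dir, fname)
                samples.append((path, class_to_idx[name]))
    return samples, class_to_idx


# Build a throwaway tree with empty placeholder files.
with tempfile.TemporaryDirectory() as root:
    for label, fname in [("0", "3.png"), ("0", "5.png"), ("1", "10.png")]:
        os.makedirs(os.path.join(root, label), exist_ok=True)
        open(os.path.join(root, label, fname), "wb").close()

    samples, class_to_idx = index_image_folder(root)
    labels = [label for _, label in samples]
```

Because the sort is on strings, folders named 0 to 9 map to indices 0 to 9, but a folder named 10 would sort before 2. For anything beyond single digits it is worth checking the mapping rather than assuming the index equals the folder name.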


1 Answer

Here's what I did for PyTorch 0.4.1 (should still work in 1.3):

import torch
import torchvision


def load_dataset():
    data_path = 'data/train/'
    # ImageFolder uses each subfolder name (0-9) as the class label.
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        num_workers=0,
        shuffle=True
    )
    return train_loader


for batch_idx, (data, target) in enumerate(load_dataset()):
    # train network
    pass
Duane answered Oct 02 '22 14:10