
How to get only specific classes from PyTorch's FashionMNIST dataset?

Tags: python, pytorch

The FashionMNIST dataset has 10 output classes. How can I get a subset of this dataset containing only specific classes? In my case, I only want images of the sneaker, pullover, sandal and shirt classes (labels 7, 2, 5 and 6 respectively).

This is how I load my dataset.

train_dataset_full = torchvision.datasets.FashionMNIST(data_folder, train = True, download = True, transform = transforms.ToTensor())

The approach I've followed is below: iterate through the dataset one sample at a time, then compare the second element of the returned tuple (index 1, the class label) to my required classes. I'm stuck here. If the comparison is true, how can I append this sample to an initially empty dataset?

sneaker = 0
pullover = 0
sandal = 0
shirt = 0
for i in range(60000):
    if train_dataset_full[i][1] == 7:
        sneaker += 1
    elif train_dataset_full[i][1] == 2:
        pullover += 1
    elif train_dataset_full[i][1] == 5:
        sandal += 1
    elif train_dataset_full[i][1] == 6:
        shirt += 1

Now, in place of sneaker += 1, pullover += 1, sandal += 1 and shirt += 1, I want to do something like empty_dataset.append(train_dataset_full[i]).

If the above approach is incorrect, please suggest another method.
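One way to get the effect of empty_dataset.append(...) is to collect the matching indices rather than the samples themselves, and then build a subset view from them. A minimal sketch, assuming the dataset is any indexable sequence of (image, label) pairs; select_indices is a hypothetical helper name, and the toy list only stands in for the real dataset:

```python
def select_indices(dataset, wanted_classes):
    """Return the indices of samples whose label is in wanted_classes."""
    wanted = set(wanted_classes)
    indices = []
    for i in range(len(dataset)):
        _, label = dataset[i]  # each item is an (image, label) pair
        if int(label) in wanted:
            indices.append(i)
    return indices


# Toy stand-in for the real dataset: (image placeholder, label) pairs.
toy_dataset = [("img0", 7), ("img1", 0), ("img2", 2), ("img3", 9), ("img4", 6)]
keep = select_indices(toy_dataset, [7, 2, 5, 6])
print(keep)  # [0, 2, 4]
```

With the real data, the resulting list can be passed to torch.utils.data.Subset(train_dataset_full, keep), which leaves the original dataset unchanged. Looping over train_dataset_full[i] runs the transform on every image, so reading the labels from train_dataset_full.targets instead would likely be faster.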

Nurav Adnab asked Sep 20 '20 03:09


1 Answer

Finally found the answer.

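import torchvision
from torchvision import transforms

# data_folder is the dataset root directory, defined elsewhere
dataset_full = torchvision.datasets.FashionMNIST(data_folder, train=True, download=True, transform=transforms.ToTensor())
# Boolean mask selecting classes 7, 2, 5 and 6
idx = (dataset_full.targets == 7) | (dataset_full.targets == 2) | (dataset_full.targets == 5) | (dataset_full.targets == 6)
# Keep only the matching labels and images (this modifies dataset_full in place)
dataset_full.targets = dataset_full.targets[idx]
dataset_full.data = dataset_full.data[idx]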
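
After this filtering, the labels are still 7, 2, 5 and 6. If the model expects contiguous labels 0 to 3 (as torch.nn.CrossEntropyLoss does for a 4-class output), they may need remapping. A rough sketch, assuming the ordering below is acceptable; kept_classes, label_map and remap_label are illustrative names, not part of torchvision:

```python
# Original FashionMNIST labels that were kept, in the order they should be renumbered.
kept_classes = [7, 2, 5, 6]  # sneaker, pullover, sandal, shirt

# Map each original label to a new contiguous label: 7 -> 0, 2 -> 1, 5 -> 2, 6 -> 3.
label_map = {old: new for new, old in enumerate(kept_classes)}


def remap_label(label):
    """Translate an original label into its new contiguous label."""
    return label_map[int(label)]


print([remap_label(y) for y in [7, 2, 5, 6]])  # [0, 1, 2, 3]
```

This function could be passed as target_transform=remap_label when constructing FashionMNIST, so that each label is translated when a sample is accessed.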

Nurav Adnab answered Oct 16 '22 09:10