Number of instances per class in pytorch dataset

I'm trying to make a simple image classifier using PyTorch. This is how I load the data into a Dataset and DataLoader:

import torch
from torch.utils.data import SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder

batch_size = 64
validation_split = 0.2
# PROJECT_PATH and CustomToTensor are defined elsewhere in my project
data_dir = PROJECT_PATH + "/categorized_products"
transform = transforms.Compose([transforms.Grayscale(), CustomToTensor()])

dataset = ImageFolder(data_dir, transform=transform)

indices = list(range(len(dataset)))
split = int(len(indices) * (1 - validation_split))  # 80/20 split

train_indices = indices[:split]
test_indices = indices[split:]

train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, num_workers=16)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=test_sampler, num_workers=16)

I want to print out the number of images in each class in training and test data separately, something like this:

In train data:

  • shoes: 20
  • shirts: 14

In test data:

  • shoes: 4
  • shirts: 3

I tried this:

from collections import Counter
print(dict(Counter(sample_tup[1] for sample_tup in dataset.imgs)))

but I got this error:

AttributeError: 'MyDataset' object has no attribute 'img'
Amin Bashiri asked Jun 11 '20 07:06




1 Answer

Use .targets to access the labels of the dataset:

print(dict(Counter(dataset.targets)))

For the MNIST dataset, for example, it prints something like this:

{5: 5421, 0: 5923, 4: 5842, 1: 6742, 9: 5949, 2: 5958, 3: 6131, 6: 5918, 7: 6265, 8: 5851}

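You can also use .classes (the list of class names, indexed by label id) or .class_to_idx (a dict mapping each class name to its label id) to translate label ids back to class names: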

print(dataset.class_to_idx)
{'0 - zero': 0,
 '1 - one': 1,
 '2 - two': 2,
 '3 - three': 3,
 '4 - four': 4,
 '5 - five': 5,
 '6 - six': 6,
 '7 - seven': 7,
 '8 - eight': 8,
 '9 - nine': 9}
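Combining the two gives counts keyed by class name, which is closer to the output the question asks for. A minimal pure-Python sketch, assuming targets is a sequence of integer labels (like ImageFolder.targets) and class_to_idx is the dataset's {name: id} dict; the helper name named_class_counts and the toy data are hypothetical:

```python
from collections import Counter


def named_class_counts(targets, class_to_idx):
    """Count labels in `targets` and key the result by class name.

    `targets`: iterable of integer labels (e.g. ImageFolder.targets).
    `class_to_idx`: {class_name: label_id} dict from the dataset.
    """
    # invert {name: id} into {id: name}
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    # int() also unwraps 0-dim tensors, which hash by identity, not value
    counts = Counter(int(label) for label in targets)
    return {idx_to_class[idx]: n for idx, n in sorted(counts.items())}


# toy stand-ins for dataset.targets and dataset.class_to_idx
targets = [0, 1, 0, 0, 1]
class_to_idx = {"shirts": 0, "shoes": 1}
print(named_class_counts(targets, class_to_idx))  # {'shirts': 3, 'shoes': 2}
```

With a real ImageFolder you would pass dataset.targets and dataset.class_to_idx directly.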

Edit: Method 1

Following up on the comments: to get the class distribution of the training and test sets separately, you can iterate over each subset:

import torch
from collections import Counter

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# labels in training set (this loads every sample, transform included)
train_classes = [label for _, label in train_dataset]
Counter(train_classes)
Counter({0: 4757,
         1: 5363,
         2: 4782,
         3: 4874,
         4: 4678,
         5: 4321,
         6: 4747,
         7: 5024,
         8: 4684,
         9: 4770})
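Method 1 is correct but can be slow, because iterating a subset calls the underlying dataset's __getitem__ for every sample; for ImageFolder that means opening and transforming each image just to read its label. The toy sketch below makes this visible without PyTorch. CountingDataset and MiniSubset are hypothetical stand-ins: the first plays the role of ImageFolder and counts loads, the second mimics the (dataset, indices) interface of torch.utils.data.Subset.

```python
class CountingDataset:
    """Hypothetical stand-in for ImageFolder that counts __getitem__ calls."""

    def __init__(self, targets):
        self.targets = list(targets)
        self.loads = 0  # number of samples "loaded" so far

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, i):
        self.loads += 1  # a real ImageFolder would open and transform an image here
        return f"image-{i}", self.targets[i]


class MiniSubset:
    """Mimics torch.utils.data.Subset: a dataset plus a list of indices."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # IndexError from self.indices ends plain `for` iteration
        return self.dataset[self.indices[i]]


ds = CountingDataset([0, 1, 0, 1, 1, 0])
train = MiniSubset(ds, [0, 2, 3, 5])

# Method 1: iterating the subset goes through __getitem__ for every sample
labels_1 = [label for _, label in train]
print(ds.loads)  # 4

# Method 2: reading labels through .indices never calls __getitem__
labels_2 = [ds.targets[i] for i in train.indices]
print(ds.loads)  # still 4
```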

Edit (2): Method 2

Since your dataset is large and, as you said, iterating over the whole training set takes considerable time, there is a faster way:

Use the subset's .indices attribute, which holds the indices of the samples in the original dataset that were selected for the subset. For example:

train_classes = [dataset.targets[i] for i in train_dataset.indices]
Counter(train_classes)  # if targets is a tensor (e.g. MNIST), use: Counter(i.item() for i in train_classes)
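The same idea works directly with the index lists the question passes to SubsetRandomSampler, so no random_split is needed. A minimal sketch, assuming targets behaves like dataset.targets and the index lists are like train_indices and test_indices from the question; the helper print_split_counts and the toy data are hypothetical, and the printed format follows the one requested in the question:

```python
from collections import Counter


def print_split_counts(name, targets, indices, class_to_idx):
    """Print and return per-class counts for the samples in `indices`.

    `targets` plays the role of dataset.targets; `indices` the role of
    the list given to SubsetRandomSampler (or Subset.indices).
    """
    idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
    # int() unwraps tensor elements so equal labels share one Counter key
    counts = Counter(int(targets[i]) for i in indices)
    print(f"In {name} data:")
    for idx in sorted(counts):
        print(f"  • {idx_to_class[idx]}: {counts[idx]}")
    return {idx_to_class[idx]: counts[idx] for idx in sorted(counts)}


# toy stand-ins for dataset.targets, dataset.class_to_idx and the split
targets = [0, 0, 0, 1, 1, 0, 1, 1]
class_to_idx = {"shirts": 0, "shoes": 1}
train_indices = [0, 1, 3, 4, 6]
test_indices = [2, 5, 7]

print_split_counts("train", targets, train_indices, class_to_idx)
print_split_counts("test", targets, test_indices, class_to_idx)
```

With the question's code you would call it with dataset.targets, train_indices (or test_indices) and dataset.class_to_idx. These counts are also a quick sanity check on the split itself: ImageFolder lists samples grouped by class folder, so slicing an unshuffled index list puts the last classes mostly or entirely into the test set.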
kHarshit answered Oct 09 '22 21:10