
Using WeightedRandomSampler in PyTorch

I need to implement a multi-label image classification model in PyTorch. My data is not balanced, however, so I used WeightedRandomSampler to build a custom dataloader. But when I iterate through that dataloader, I get the error: IndexError: list index out of range

I implemented the following code based on this link: https://discuss.pytorch.org/t/balanced-sampling-between-classes-with-torchvision-dataloader/2703/3?u=surajsubramanian

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

def make_weights_for_balanced_classes(images, nclasses):
    # Count the number of samples in each class
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1
    # Weight each class inversely to its frequency
    weight_per_class = [0.] * nclasses
    N = float(sum(count))
    for i in range(nclasses):
        weight_per_class[i] = N/float(count[i])
    # Assign every sample the weight of its class
    weight = [0] * len(images)
    for idx, val in enumerate(images):
        weight[idx] = weight_per_class[val[1]]
    return weight

weights = make_weights_for_balanced_classes(train_dataset.imgs, len(full_dataset.classes))
weights = torch.DoubleTensor(weights)
sampler = WeightedRandomSampler(weights, len(weights))

train_loader = DataLoader(train_dataset, batch_size=4, sampler=sampler, pin_memory=True)

Based on the answer in https://stackoverflow.com/a/60813495/10077354, the following is my updated code. But even then, when I create a dataloader with loader = DataLoader(full_dataset, batch_size=4, sampler=sampler), len(loader) returns 1.

class_counts = [1691, 743, 2278, 1271]
num_samples = np.sum(class_counts)
labels = [tag for _,tag in full_dataset.imgs] 

class_weights = [num_samples/class_counts[i] for i in range(len(class_counts))]
weights = [class_weights[labels[i]] for i in range(num_samples)]
sampler = WeightedRandomSampler(torch.DoubleTensor(weights), num_samples)

Thanks a lot in advance!

I included a utility function based on the accepted answer below:

def sampler_(dataset):
    dataset_counts = imageCount(dataset)
    num_samples = sum(dataset_counts)
    labels = [tag for _,tag in dataset]

    class_weights = [num_samples/dataset_counts[i] for i in range(n_classes)]
    weights = [class_weights[labels[i]] for i in range(num_samples)]
    sampler = WeightedRandomSampler(torch.DoubleTensor(weights), int(num_samples))
    return sampler

The imageCount function counts the number of images of each class in the dataset. Each row in the dataset contains the image and its class, so we use the second element of the tuple.

def imageCount(dataset):
    image_count = [0]*(n_classes)
    for img in dataset:
        image_count[img[1]] += 1
    return image_count
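
Note that looping over the dataset as above loads every image just to read its label. A possible shortcut, only a sketch and assuming the labels are already available as a plain list of integers (for example the targets attribute of a torchvision ImageFolder), is to count them directly. The name per_sample_weights and the toy labels are made up for illustration:

```python
from collections import Counter


def per_sample_weights(labels):
    """Return one weight per sample, inversely proportional to its class frequency."""
    counts = Counter(labels)   # class index -> number of samples in that class
    total = len(labels)
    return [total / counts[label] for label in labels]


# Hypothetical labels: three samples of class 0 and one sample of class 1
labels = [0, 0, 0, 1]
weights = per_sample_weights(labels)
print(weights)
```

The resulting list has exactly one entry per sample, which is the shape WeightedRandomSampler expects for its weights argument.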
Suraj Subramanian asked Mar 23 '20 10:03

2 Answers

That code looks a bit complex... You can try the following:

import torch
from torch.utils.data import WeightedRandomSampler

# Let there be 9 samples in class 0 and 1 sample in class 1
class_counts = [9.0, 1.0]
num_samples = sum(class_counts)
labels = [0] * 9 + [1]  # corresponding labels of the samples

class_weights = [num_samples/class_counts[i] for i in range(len(class_counts))]
weights = [class_weights[labels[i]] for i in range(int(num_samples))]
sampler = WeightedRandomSampler(torch.DoubleTensor(weights), int(num_samples))
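
To check that a sampler built this way behaves as intended, one option is to plug it into a DataLoader and count the labels it draws. The following is only a sketch under stated assumptions: the toy TensorDataset, the 90/10 class split, and the variable names are made up for illustration and are not from the question. The one thing it relies on is that the weights list has one entry per sample of the dataset handed to the DataLoader, since the sampler returns indices into that list.

```python
from collections import Counter

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Toy data (hypothetical): 90 samples of class 0 and 10 samples of class 1
labels = [0] * 90 + [1] * 10
features = torch.randn(len(labels), 3)
dataset = TensorDataset(features, torch.tensor(labels))

# One weight per sample, inversely proportional to the size of its class
class_counts = Counter(labels)
weights = [len(labels) / class_counts[label] for label in labels]

sampler = WeightedRandomSampler(
    torch.DoubleTensor(weights), num_samples=len(weights), replacement=True
)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

# Count how often each class is drawn over one pass through the loader
drawn = Counter()
for _, batch_labels in loader:
    drawn.update(batch_labels.tolist())

print(len(loader))  # number of batches: ceil(100 / 4)
print(drawn)        # the two classes should appear in roughly equal numbers
```

Because sampling is done with replacement, the minority samples are repeated within an epoch, so the exact counts vary from run to run.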
ccl answered Oct 12 '22 14:10



Here is an alternative solution:

import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler

counts = np.bincount(y)
labels_weights = 1. / counts
weights = labels_weights[y]
sampler = WeightedRandomSampler(weights, len(weights))

where y is an array of integer labels, one per sample, with shape (n_samples,) and values encoded as [0, ..., n_classes - 1].

The weights won't add up to 1, which is fine according to the official docs.
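
Since the weights are only relative, what matters is their ratio. A small NumPy-only sketch (the label array y below is a hypothetical example) shows that with 1 / counts every class ends up with the same total weight, so each class is equally likely to be drawn once the weights are normalized:

```python
import numpy as np

# Hypothetical labels: 4 samples of class 0, 2 of class 1, 1 of class 2
y = np.array([0, 0, 0, 0, 1, 1, 2])

counts = np.bincount(y)              # samples per class
weights = (1.0 / counts)[y]          # one weight per sample

# Total weight carried by each class: count * (1 / count) for every class
per_class_total = np.bincount(y, weights=weights)

# Normalizing the weights gives the per-sample draw probabilities
probs = weights / weights.sum()
per_class_prob = np.bincount(y, weights=probs)

print(per_class_total)
print(per_class_prob)
```

Multiplying all weights by a constant leaves probs unchanged, which is why scaling by num_samples, as in the other answer, yields the same sampling distribution.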

Aray Karjauv answered Oct 12 '22 13:10
