I am training a deep learning model in PyTorch for binary classification, and I have a dataset containing unbalanced class proportions. My minority class makes up about 10% of the given observations. To avoid the model learning to just predict the majority class, I want to use the WeightedRandomSampler from torch.utils.data in my DataLoader.
Let's say I have 1000 observations (900 in class 0, 100 in class 1), and a batch size of 100 for my dataloader.
Without weighted random sampling, I would expect each training epoch to consist of 10 batches.
Here is a small snippet that uses WeightedRandomSampler.
First, define a function that computes one weight per sample, inversely proportional to the frequency of its class:
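def make_weights_for_balanced_classes(images, nclasses):
    # images is a list of (item, class_index) pairs, e.g. ImageFolder.imgs
    n_images = len(images)

    # Count how many samples belong to each class
    count_per_class = [0] * nclasses
    for _, image_class in images:
        count_per_class[image_class] += 1

    # Weight of a class = N / count, so rare classes get larger weights
    weight_per_class = [0.] * nclasses
    for i in range(nclasses):
        weight_per_class[i] = float(n_images) / float(count_per_class[i])

    # Each sample inherits the weight of its class
    weights = [0] * n_images
    for idx, (image, image_class) in enumerate(images):
        weights[idx] = weight_per_class[image_class]
    return weights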
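If the labels are already available as a flat list or tensor, the same weights can be computed without explicit loops. This is only a minimal sketch: the name `weights_from_labels` is hypothetical (it is not part of the original answer), and it assumes integer labels in the range 0..nclasses-1 with at least one sample per class.

```python
import torch

def weights_from_labels(labels, nclasses):
    """Return one weight per sample, inversely proportional to its class frequency."""
    labels = torch.as_tensor(labels, dtype=torch.long)
    # Number of samples in each class
    counts = torch.bincount(labels, minlength=nclasses).double()
    # N / count_c, the same formula as the loop version above
    weight_per_class = labels.numel() / counts
    # Look up the class weight for every sample
    return weight_per_class[labels]

# The 900 / 100 split from the question
labels = [0] * 900 + [1] * 100
weights = weights_from_labels(labels, 2)
```

With these weights the total weight of each class is the same (900 * 1000/900 = 100 * 1000/100 = 1000), so the sampler draws from both classes with equal probability.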
And after this, use it in the following way:
import torch
from torchvision import datasets

dataset_train = datasets.ImageFolder(traindir)

# For an unbalanced dataset we create a weighted sampler
weights = make_weights_for_balanced_classes(dataset_train.imgs, len(dataset_train.classes))
weights = torch.DoubleTensor(weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

# Do not pass shuffle=True together with sampler: DataLoader raises a ValueError if both are set
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size,
                                           sampler=sampler, num_workers=args.workers, pin_memory=True)
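To check what this does to the epoch length and the class balance in the scenario from the question, here is a small self-contained sketch. It uses a synthetic TensorDataset as a stand-in for the real data (an assumption; the feature size of 8 and the variable names are made up for illustration).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Synthetic stand-in for the question's data: 900 samples of class 0, 100 of class 1
labels = torch.cat([torch.zeros(900, dtype=torch.long),
                    torch.ones(100, dtype=torch.long)])
features = torch.randn(1000, 8)
dataset = TensorDataset(features, labels)

# Per-sample weight = N / (number of samples in that sample's class)
class_counts = torch.bincount(labels).double()
sample_weights = (len(labels) / class_counts)[labels]

# num_samples = dataset size, drawn with replacement
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

n_batches = len(loader)
drawn = torch.cat([y for _, y in loader])
minority_fraction = (drawn == 1).float().mean().item()
print(n_batches, round(minority_fraction, 2))
```

Because num_samples is set to the dataset size, an epoch still consists of 10 batches, and roughly half of the drawn samples come from the minority class. Sampling is done with replacement, so within one epoch minority samples are repeated and part of the majority class is not seen at all.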