 

How to calculate unbalanced weights for BCEWithLogitsLoss in PyTorch

I am trying to solve a multilabel problem with 270 labels, and I have converted the target labels into one-hot encoded form. I am using BCEWithLogitsLoss(). Since the training data is unbalanced, I am using the pos_weight argument, but I am a bit confused.

pos_weight (Tensor, optional) – a weight of positive examples. Must be a vector with length equal to the number of classes.

Do I need to give the total count of positive values of each label as a tensor, or do they mean something else by weights?

asked Jul 13 '19 by Naresh


3 Answers

The PyTorch documentation for BCEWithLogitsLoss recommends that pos_weight be the ratio between the negative and positive counts for each class.

So, if len(dataset) is 1000 and element 0 of your multihot encoding has 100 positive counts, then element 0 of the pos_weight vector should be 900/100 = 9. That means the binary cross-entropy loss will behave as if the dataset contained 900 positive examples instead of 100.

Here is my implementation:

(new, based on this post)

pos_weight = (y == 0.).sum(dim=0) / y.sum(dim=0)  # per-class negative/positive ratio
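For a 2D multihot target matrix y, summing along dim=0 gives one ratio per class; a tiny sanity check with made-up targets:

import torch

# hypothetical multihot targets: 4 samples, 3 classes
y = torch.tensor([[1., 0., 0.],
                  [1., 1., 0.],
                  [0., 1., 0.],
                  [1., 0., 1.]])

print((y == 0.).sum(dim=0) / y.sum(dim=0))
# tensor([0.3333, 1.0000, 3.0000])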

(original)

import numpy as np
import torch

def calculate_pos_weights(class_counts, num_samples):
    # neg_count / pos_count for every class; the epsilon avoids division by zero
    pos_weights = np.ones_like(class_counts, dtype=np.float64)
    neg_counts = [num_samples - pos_count for pos_count in class_counts]
    for cdx, (pos_count, neg_count) in enumerate(zip(class_counts, neg_counts)):
        pos_weights[cdx] = neg_count / (pos_count + 1e-5)

    return torch.as_tensor(pos_weights, dtype=torch.float)

Where class_counts is just a column-wise sum of the positive samples and num_samples is the size of the dataset. I posted it on the PyTorch forum and one of the PyTorch devs gave it his blessing.
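A quick sanity check with hypothetical counts (the numbers below are made up for illustration):

import numpy as np

# hypothetical: 1000 samples, 3 classes with these positive counts
class_counts = np.array([100, 500, 900])
print(calculate_pos_weights(class_counts, num_samples=1000))
# tensor([9.0000, 1.0000, 0.1111])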

answered Jan 01 '23 by crypdick


PyTorch solution

Well, actually, I have gone through the docs and you can indeed simply use pos_weight.

This argument gives weight to the positive sample of each class; hence, if you have 270 classes, you should pass a torch.Tensor with shape (270,) defining the weight for each class.

Here is a marginally modified snippet from the documentation:

import torch

# 270 classes, batch size = 64
target = torch.ones([64, 270], dtype=torch.float32)
# Logits output by your network, no final activation applied
output = torch.full([64, 270], 0.9)
# Weights, each equal to one; you can substitute your own here
pos_weight = torch.ones([270])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target)  # -log(sigmoid(0.9))

Self-made solution

When it comes to weighting each class's contribution to the loss, there is no built-in solution, but you can code one yourself quite easily:

import torch

class WeightedMultilabel(torch.nn.Module):
    def __init__(self, weights: torch.Tensor):
        super().__init__()
        # reduction="none" keeps the per-element losses so they can be weighted
        self.loss = torch.nn.BCEWithLogitsLoss(reduction="none")
        self.weights = weights

    def forward(self, outputs, targets):
        # per-class weights broadcast across the batch dimension
        return (self.loss(outputs, targets) * self.weights).mean()

The weights tensor has to be of the same length as the number of classes in your multilabel classification (270), each entry giving the weight for one specific class.
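A minimal usage sketch (shapes and values below are illustrative):

weights = torch.rand(270)                         # one weight per class
criterion = WeightedMultilabel(weights)
outputs = torch.randn(64, 270)                    # raw logits from the network
targets = torch.randint(0, 2, (64, 270)).float()  # multihot labels
loss = criterion(outputs, targets)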

Calculating weights

You just add up the labels of every sample in your dataset, divide by the minimum value, and invert at the end.

A rough snippet:

# accumulate per-class positive counts across the dataset
weights = torch.zeros_like(dataset[0])
for element in dataset:
    weights += element

# rarest class gets weight 1, all others get weights below 1
weights = 1 / (weights / torch.min(weights))

Using this approach, the class occurring the least often gives normal loss, while all the others get weights smaller than 1.

It might cause some instability during training though, so you may want to experiment with those values a little (maybe a log transform instead of a linear one?).
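A hedged sketch of that log-transform idea, where counts stands for the accumulated per-class sums from the loop above (before the final inversion); the names are illustrative:

# compress the dynamic range of the counts before inverting
# (assumes every class occurs at least once, otherwise log1p(0) = 0 divides by zero)
log_counts = torch.log1p(counts)
weights = torch.min(log_counts) / log_counts  # rarest class still gets weight 1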

Other approach

You may also think about upsampling/downsampling (though this operation is complicated, as adding/deleting samples would affect the other classes as well, so advanced heuristics would be needed, I think).

answered Jan 01 '23 by Szymon Maszke


Maybe it is a little late, but here is how I calculate the same thing. Looking into the documentation:

For example, if a dataset contains 100 positive and 300 negative examples of a single class, then pos_weight for the class should be equal to 300/100 = 3.

So an easy way to calculate the positive weights is to use the tensor methods on your label vector y, in my case train_dataset.data.y, and then compute the total negative labels.

# per-class positive counts (sum over the batch dimension)
num_positives = torch.sum(train_dataset.data.y, dim=0)
num_negatives = len(train_dataset.data.y) - num_positives
pos_weight = num_negatives / num_positives
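Note that if a class has no positive examples at all, the division above yields inf; a hedged guard (the clamp floor of 1 is an arbitrary assumption):

# avoid division by zero for classes with no positive examples
pos_weight = num_negatives / num_positives.clamp(min=1)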

Then the weights can be used easily as:

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
answered Jan 01 '23 by AlfilAlex