 

How to Use Class Weights with Focal Loss in PyTorch for Imbalanced Multiclass Classification

I am working on multiclass classification (4 classes) for a language task, using a BERT model for the classification, and I am following this blog as a reference. My fine-tuned BERT model returns nn.LogSoftmax(dim=1) output.

My data is quite imbalanced, so I used sklearn.utils.class_weight.compute_class_weight to compute the class weights and passed them to the loss.

from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
weights = torch.tensor(class_weights, dtype=torch.float)
cross_entropy = nn.NLLLoss(weight=weights)
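
For context, this is roughly how I apply the loss during training (a minimal sketch; model, input_ids, attention_mask, and labels are illustrative names, not from the blog):

log_probs = model(input_ids, attention_mask)  # nn.LogSoftmax(dim=1) output, shape (batch, 4)
loss = cross_entropy(log_probs, labels)       # labels: integer classes 0..3, shape (batch,)
loss.backward()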

My results were not so good, so I thought of experimenting with focal loss; here is my code for it.

class FocalLoss(nn.Module):
  def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
    super(FocalLoss, self).__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.logits = logits  # unused here; kept from the original snippet
    self.reduce = reduce

  def forward(self, inputs, targets):
    # Despite the name, this is multiclass cross-entropy, not binary cross-entropy.
    # reduction='none' keeps one loss value per sample so pt is per-sample too;
    # with the default mean reduction, pt would collapse to a single scalar.
    BCE_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets)

    pt = torch.exp(-BCE_loss)  # probability of the true class, per sample
    F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss

    if self.reduce:
      return torch.mean(F_loss)
    else:
      return F_loss
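
For completeness, calling it looks like this (a sketch with illustrative tensors; note the inputs are raw logits rather than log-probabilities, since nn.CrossEntropyLoss applies log-softmax internally):

criterion = FocalLoss(alpha=1, gamma=2)
logits = torch.randn(8, 4)           # illustrative batch of raw logits
targets = torch.randint(0, 4, (8,))  # illustrative integer class labels
loss = criterion(logits, targets)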

I have three questions now. The first is the most important:

  1. Should I use class weights with focal loss?
  2. If I have to implement weights inside this focal loss, can I use the weight parameter of nn.CrossEntropyLoss()?
  3. If this implementation is incorrect, what should the proper code be, including the weights (if possible)?
asked Nov 09 '20 by Deshwal


2 Answers

You may find the answers to your questions as follows:

  1. Focal loss handles the class imbalance automatically, hence explicit class weights are not required: the alpha and gamma factors handle the imbalance in the focal loss equation.
  2. There is no need for extra weights, because focal loss already handles them via the alpha and gamma modulating factors.
  3. The implementation you mentioned is correct according to the focal loss formula, but I had trouble getting my model to converge with this version, so I used the following implementation from the mmdetection framework:
    # From mmdetection's py_sigmoid_focal_loss; sigmoid-based, so it expects
    # binary (per-class) targets. weight_reduce_loss is mmdetection's helper
    # that applies optional element-wise weights and the requested reduction.
    import torch.nn.functional as F
    from mmdet.models.losses.utils import weight_reduce_loss

    def py_sigmoid_focal_loss(pred, target, weight=None, gamma=2.0,
                              alpha=0.25, reduction='mean', avg_factor=None):
        pred_sigmoid = pred.sigmoid()
        target = target.type_as(pred)
        # pt here is the probability mass on the wrong side of each target
        pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        focal_weight = (alpha * target + (1 - alpha) *
                        (1 - target)) * pt.pow(gamma)
        loss = F.binary_cross_entropy_with_logits(
            pred, target, reduction='none') * focal_weight
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss
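
Note that because this version is sigmoid-based, a 4-class single-label task needs one-hot encoded targets. A minimal sketch of calling it that way (assuming the signature above; the tensors are illustrative):

    import torch
    labels = torch.randint(0, 4, (8,))                  # integer class labels
    targets = F.one_hot(labels, num_classes=4).float()  # binary target per class
    logits = torch.randn(8, 4)                          # raw logits, one column per class
    loss = py_sigmoid_focal_loss(logits, targets, gamma=2.0, alpha=0.25)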

You can also experiment with other available focal loss versions.

answered Oct 08 '22 by user3411639

I think the OP would have gotten their answer by now; I am writing this for other people who might stumble upon this question.

There is one problem in the OP's implementation of focal loss:

  1. F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

In this line, the same scalar alpha is multiplied into every sample's loss regardless of the sample's class, so alpha does not act as a per-class weight. A very good implementation of focal loss can be found here. However, that implementation only handles binary classification, since it stores alpha and 1 - alpha for the two classes in its self.alpha tensor.

In the case of multiclass or multi-label classification, the self.alpha tensor should contain as many elements as there are labels. The values could be the inverse label frequency or the inverse normalized label frequency (just be cautious with labels that have a frequency of 0).
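
To make that concrete, here is a minimal sketch (my own illustration, not the linked implementation) of a multiclass focal loss with a per-class alpha built from inverse label frequencies; train_labels is the label array from the question, everything else is illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiClassFocalLoss(nn.Module):
    def __init__(self, alpha, gamma=2.0):
        super().__init__()
        # alpha: 1-D tensor with one weight per class (e.g. inverse label frequency)
        self.register_buffer('alpha', alpha)
        self.gamma = gamma

    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=1)
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log p of true class
        pt = log_pt.exp()                                              # p of true class
        alpha_t = self.alpha[targets]                                  # per-sample class weight
        loss = -alpha_t * (1 - pt) ** self.gamma * log_pt
        return loss.mean()

# Per-class alpha from inverse normalized label frequency (clamp guards zero counts)
counts = torch.bincount(torch.as_tensor(train_labels), minlength=4).float()
alpha = 1.0 / counts.clamp(min=1)
alpha = alpha / alpha.sum()
criterion = MultiClassFocalLoss(alpha, gamma=2.0)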

answered Oct 08 '22 by Ashish