How to Use Class Weights with Focal Loss in PyTorch for an Imbalanced Dataset in Multi-Class Classification

I am working on multi-class classification (4 classes) for a language task, using a BERT model as the classifier. I am following this blog as a reference. My fine-tuned BERT model ends with nn.LogSoftmax(dim=1), so it returns log-probabilities.

My data is pretty imbalanced, so I used sklearn.utils.class_weight.compute_class_weight to compute the class weights and passed them to the loss.

# classes and y must be passed by keyword in recent scikit-learn releases
class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
weights = torch.tensor(class_weights, dtype=torch.float)
cross_entropy = nn.NLLLoss(weight=weights)
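
For reference, the 'balanced' heuristic is documented as n_samples / (n_classes * np.bincount(y)). The sketch below reproduces that arithmetic with NumPy only; the label array is hypothetical and chosen purely for illustration.

```python
import numpy as np

# Hypothetical labels for a 4-class problem (illustration only).
train_labels = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2, 3])

# Count how often each class occurs.
classes, counts = np.unique(train_labels, return_counts=True)

# 'balanced' heuristic: n_samples / (n_classes * count_of_class).
balanced_weights = len(train_labels) / (len(classes) * counts)

for c, w in zip(classes, balanced_weights):
    print(f"class {c}: weight {w:.4f}")
```

Rare classes get weights above 1 and frequent classes get weights below 1, which is what nn.NLLLoss(weight=...) then uses to scale each sample's loss.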

My results were not so good, so I thought of experimenting with focal loss, and I have the following code for it.

class FocalLoss(nn.Module):
  def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
    super(FocalLoss, self).__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.logits = logits
    self.reduce = reduce

  def forward(self, inputs, targets):
    BCE_loss = nn.CrossEntropyLoss()(inputs, targets)

    pt = torch.exp(-BCE_loss)
    F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

    if self.reduce:
      return torch.mean(F_loss)
    return F_loss

I have three questions now, the first being the most important:

  1. Should I use class weights with focal loss?
  2. If I have to implement weights inside this focal loss, can I use the weight parameter of nn.CrossEntropyLoss()?
  3. If this implementation is incorrect, what would the proper code be, including the weights (if possible)?

2 Answers

You may find answers to your questions as follows:

  1. Focal loss automatically handles class imbalance, so separate class weights are not required. The alpha and gamma factors in the focal loss equation take care of the imbalance.
  2. There is no need for extra weights, because focal loss accounts for them through its alpha and gamma modulating factors.
  3. The implementation you mentioned follows the focal loss formula, but I had trouble getting my model to converge with that version, so I used the following implementation from the mmdetection framework:
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

You can also experiment with another available focal loss version.
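
To make the role of gamma concrete, here is a small sketch of the modulating factor (1 - pt) ** gamma on its own, where pt is the predicted probability of the true class. The helper name modulating_factor is hypothetical and used only for this illustration.

```python
def modulating_factor(pt: float, gamma: float = 2.0) -> float:
    """Focal-loss scaling term (1 - pt) ** gamma for a true-class probability pt."""
    return (1.0 - pt) ** gamma


# Confident, borderline, and badly misclassified examples.
for pt in (0.9, 0.5, 0.1):
    print(f"pt={pt:.1f}  factor={modulating_factor(pt):.2f}")
# pt=0.9  factor=0.01
# pt=0.5  factor=0.25
# pt=0.1  factor=0.81
```

With gamma=2, an example the model already gets right with pt=0.9 contributes about 1% of its cross-entropy loss, while a hard example with pt=0.1 keeps about 81%. With gamma=0 the factor is 1 and the loss reduces to ordinary cross-entropy.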

I think OP has gotten an answer by now. I am writing this for other people who might wonder about the same thing.

There is one problem in OP's implementation of focal loss:

  1. F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

In this line, the same alpha value multiplies the loss for every class, so it cannot weight the classes differently. Additionally, the code does not show how a per-sample pt is obtained: nn.CrossEntropyLoss() averages over the batch by default, so torch.exp(-BCE_loss) is a single batch-level value. A very good implementation of focal loss can be found here, but it only covers binary classification, since its self.alpha tensor holds alpha and 1-alpha for the two classes.

For multi-class or multi-label classification, the self.alpha tensor should contain as many elements as there are labels. The values could be the inverse label frequencies or the normalized inverse label frequencies (just be cautious with labels that have a frequency of 0).
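
A minimal multi-class sketch of this idea follows, assuming the model outputs log-probabilities (as with the nn.LogSoftmax(dim=1) head in the question) and that targets are class indices. The names inverse_frequency_alpha and MultiClassFocalLoss are hypothetical, and normalising alpha to a mean of 1 is one possible choice rather than a requirement.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def inverse_frequency_alpha(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Per-class alpha from inverse label frequency, normalised to mean 1."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    # Guard against classes with zero samples to avoid division by zero.
    inv = 1.0 / counts.clamp(min=1.0)
    return inv * num_classes / inv.sum()


class MultiClassFocalLoss(nn.Module):
    """Focal loss on log-probabilities with an optional per-class alpha tensor."""

    def __init__(self, alpha=None, gamma=2.0, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        # Registered as a buffer so alpha moves with the module across devices.
        self.register_buffer("alpha", alpha)

    def forward(self, log_probs, targets):
        # log_probs: (N, C) log-probabilities; targets: (N,) class indices.
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        # Per-sample focal term: -(1 - pt)^gamma * log(pt).
        loss = -((1.0 - pt) ** self.gamma) * log_pt
        if self.alpha is not None:
            # Look up each sample's class weight.
            loss = self.alpha[targets] * loss
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss


# Usage with hypothetical data (4 classes, imbalanced labels).
labels = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 2, 3])
alpha = inverse_frequency_alpha(labels, num_classes=4)
criterion = MultiClassFocalLoss(alpha=alpha, gamma=2.0)
log_probs = F.log_softmax(torch.randn(10, 4), dim=1)
loss = criterion(log_probs, labels)
```

Note that with reduction="mean" this sketch takes a plain average over the batch, whereas nn.NLLLoss with a weight tensor divides by the sum of the weights of the targets, so the two are not directly comparable in scale. Setting gamma=0 and alpha=None recovers the ordinary negative log-likelihood.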

