I have a multi-label classification problem with 11 classes and around 4k examples. Each example can have anywhere from 1 to 4 or 5 labels. At the moment, I'm training a separate classifier for each class with log_loss. As you can expect, training 11 classifiers takes quite some time, so I would like to try another approach and train a single classifier instead. The idea is that the last layer of this classifier would have 11 nodes and would output one real number per class, which a sigmoid would then convert to a probability. The loss I want to optimize is the mean of the log_loss over all classes.
Unfortunately, I'm fairly new to PyTorch, and even after reading the source code of the losses, I can't figure out whether one of the existing losses does exactly what I want or whether I should create a new loss. If I need a new one, I don't really know how to write it.
To be very specific: for each element of the batch, I want to provide one vector of size 11 containing a real number for each label (the larger the value, the closer the predicted probability for that class is to 1) and one vector of size 11 containing a 1 at every true label and 0 elsewhere. I then want to compute the mean log_loss over all 11 labels and optimize my classifier based on that loss.
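Written out, the quantity described above would look roughly like the following. The symbols are assumptions introduced here for illustration: N is the batch size, C = 11 is the number of labels, z_{ic} is the raw network output for example i and label c, y_{ic} is 1 if label c is true for example i and 0 otherwise, and σ is the sigmoid.

```latex
\mathcal{L}
  = -\frac{1}{N C} \sum_{i=1}^{N} \sum_{c=1}^{C}
    \Big[ y_{ic} \log \sigma(z_{ic})
        + (1 - y_{ic}) \log\big(1 - \sigma(z_{ic})\big) \Big],
\qquad
\sigma(z) = \frac{1}{1 + e^{-z}}
```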
Any help would be greatly appreciated :)
The difference between multi-class and multi-label classification is that in multi-class problems the classes are mutually exclusive (each example belongs to exactly one class), whereas in multi-label problems each label is its own binary classification task, and the tasks are usually related to one another.
Multi-output classification is the more general setting in which a model predicts two or more outputs at once for every input, rather than a single output. Multi-label classification is the case where each of those outputs is binary.
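To make the distinction concrete, here is a small sketch of how targets and losses typically differ between the two settings in PyTorch. The batch size, the label indices, and the variable names are made up for illustration.

```python
import torch

torch.manual_seed(0)
num_classes = 11
logits = torch.randn(3, num_classes)  # 3 examples, one raw score per class

# Multi-class: exactly one class per example, so the target is a class index
# and softmax turns the scores into probabilities that sum to 1 per example.
multiclass_target = torch.tensor([2, 0, 7])
multiclass_probs = torch.softmax(logits, dim=1)
multiclass_loss = torch.nn.CrossEntropyLoss()(logits, multiclass_target)

# Multi-label: any number of labels per example, so the target is a
# multi-hot float vector and each score gets its own independent sigmoid.
multilabel_target = torch.zeros(3, num_classes)
multilabel_target[0, [1, 4]] = 1.0
multilabel_target[1, [0]] = 1.0
multilabel_target[2, [3, 5, 9, 10]] = 1.0
multilabel_probs = torch.sigmoid(logits)
multilabel_loss = torch.nn.BCEWithLogitsLoss()(logits, multilabel_target)

print(multiclass_probs.sum(dim=1))  # each row sums to 1
print(multilabel_probs.sum(dim=1))  # rows are not constrained to sum to 1
```

Note that CrossEntropyLoss takes integer class indices here, while BCEWithLogitsLoss takes float targets with the same shape as the logits.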
You are looking for torch.nn.BCELoss. Here's example code:
import torch
batch_size = 2
num_classes = 11
loss_fn = torch.nn.BCELoss()
outputs_before_sigmoid = torch.randn(batch_size, num_classes)
sigmoid_outputs = torch.sigmoid(outputs_before_sigmoid)
# randints in [0, 2), cast to float because the BCE losses require float targets.
target_classes = torch.randint(0, 2, (batch_size, num_classes)).float()
loss = loss_fn(sigmoid_outputs, target_classes)
# alternatively, use BCE with logits, on outputs before sigmoid.
# This fuses the sigmoid into the loss and is more numerically stable.
loss_fn_2 = torch.nn.BCEWithLogitsLoss()
loss2 = loss_fn_2(outputs_before_sigmoid, target_classes)
# compare with a tolerance: the two code paths can differ by rounding error.
assert torch.allclose(loss, loss2)
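For completeness, here is a minimal sketch of how a single 11-output classifier could be trained with this loss. The network architecture, the feature size (`num_features`), the optimizer settings, and the random data are all placeholders chosen for illustration, not something taken from the question.

```python
import torch
from torch import nn

torch.manual_seed(0)
num_features = 20  # hypothetical input size
num_classes = 11

# The last layer has 11 outputs and NO sigmoid: BCEWithLogitsLoss applies it.
model = nn.Sequential(
    nn.Linear(num_features, 32),
    nn.ReLU(),
    nn.Linear(32, num_classes),
)
loss_fn = nn.BCEWithLogitsLoss()  # default reduction is the mean over all elements
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Random stand-in data: 64 examples with multi-hot float targets.
inputs = torch.randn(64, num_features)
targets = torch.randint(0, 2, (64, num_classes)).float()

losses = []
for step in range(50):
    optimizer.zero_grad()
    logits = model(inputs)
    loss = loss_fn(logits, targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# At inference time, apply the sigmoid explicitly and threshold each label.
with torch.no_grad():
    probs = torch.sigmoid(model(inputs))
    predicted = (probs > 0.5).int()
```

The 0.5 threshold is only a common default; a per-label threshold tuned on validation data may work better when labels are imbalanced.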