I am performing an NLP task where I analyze a document and classify it into one of six categories. However, I perform this classification for three different time periods, so the final output is an array of three integers (sparse labels), where each integer is a category from 0 to 5. A label looks like this: [1, 4, 5].
I am using BERT and am trying to decide what type of head I should attach to it, as well as what type of loss function I should use. Would it make sense to take BERT's output of size 1024, run it through a `Dense` layer with 18 neurons, and then reshape it into something of size (3, 6)?
Finally, I assume I would use Sparse Categorical Cross-Entropy as my loss function?
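For reference, the proposed head can be sketched as below. This is a minimal sketch that assumes PyTorch rather than Keras (to match the answer), uses random tensors in place of a pooled 1024-dimensional BERT output, and names the layer `proposed_head` purely for illustration. PyTorch's `nn.CrossEntropyLoss` takes integer class labels, so it plays the role of Keras' sparse categorical cross-entropy here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size = 4
pooled = torch.randn(batch_size, 1024)           # stand-in for pooled BERT output (assumption)
labels = torch.tensor([[1, 4, 5]] * batch_size)  # (batch, 3): one class 0-5 per period

proposed_head = nn.Linear(1024, 18)              # hypothetical name: 3 periods x 6 classes
logits = proposed_head(pooled).view(-1, 3, 6)    # (batch, 3, 6)

criterion = nn.CrossEntropyLoss()
# Flatten the period axis into the batch axis: (batch*3, 6) logits vs (batch*3,) labels
loss = criterion(logits.reshape(-1, 6), labels.reshape(-1))
```

A single `Linear(1024, 18)` reshaped to (3, 6) is equivalent to three separate `Linear(1024, 6)` layers, since each output row has its own weights. With the default mean reduction, the flattened loss equals the sum of the three per-period losses divided by 3.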
BERT's final hidden state has shape (512, 1024), i.e. (sequence length, hidden size) for BERT-large with a 512-token input. You can either take the first token, which is the [CLS] token, or average-pool over the tokens. Either way, your pooled output has shape (1024,). Now simply attach three linear layers of shape (1024, 6), one per time period, as in `nn.Linear(1024, 6)`, and pass their outputs into the loss function below. (You can make the heads more complex if you want to.)
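A minimal sketch of this head, assuming PyTorch and that `hidden_states` is BERT's last hidden layer of shape (batch, seq_len, 1024), e.g. `last_hidden_state` from a Hugging Face `BertModel`. The class name `MultiPeriodHead` and its parameters are hypothetical, and random tensors stand in for real BERT output. The mean-pooling branch uses the attention mask so that padding tokens do not dilute the average.

```python
import torch
import torch.nn as nn


class MultiPeriodHead(nn.Module):
    """Hypothetical head: pools BERT hidden states, then one classifier per time period."""

    def __init__(self, hidden_size=1024, num_classes=6, num_periods=3, pooling="cls"):
        super().__init__()
        self.pooling = pooling
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_size, num_classes) for _ in range(num_periods)]
        )

    def forward(self, hidden_states, attention_mask):
        # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len), 1 = real token
        if self.pooling == "cls":
            pooled = hidden_states[:, 0]  # first token is [CLS]
        else:
            # Average only over real tokens
            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
            pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        # One (batch, num_classes) tensor of raw logits per period
        return [head(pooled) for head in self.heads]


torch.manual_seed(0)
hidden = torch.randn(2, 512, 1024)        # stand-in for BERT-large last_hidden_state
mask = torch.ones(2, 512, dtype=torch.long)
mask[1, 100:] = 0                          # second sequence is padded after 100 tokens
outputs = MultiPeriodHead(pooling="mean")(hidden, mask)
```

Keeping the heads in an `nn.ModuleList` registers their parameters with the module, so an optimizer built from `model.parameters()` updates all three.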
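Simply add up the three losses and call backward. Note that `nn.CrossEntropyLoss` expects raw (unnormalized) logits of shape (batch, 6) and integer class labels of shape (batch,), so do not apply a softmax in the heads. Remember that in PyTorch you can call `loss.backward()` on any scalar tensor.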
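```python
import torch.nn as nn

criterion = nn.CrossEntropyLoss()  # raw logits (N, 6) vs integer labels (N,)

def loss(time1output, time2output, time3output, time1label, time2label, time3label):
    # One cross-entropy term per time period, summed into a single scalar
    loss1 = criterion(time1output, time1label)
    loss2 = criterion(time2output, time2label)
    loss3 = criterion(time3output, time3label)
    return loss1 + loss2 + loss3
```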
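The sketch below shows how a label array like [1, 4, 5], batched to shape (batch, 3), splits into one target column per period, how the three losses are summed, and that a single `backward()` call populates gradients for all three heads. It assumes PyTorch, uses random features in place of pooled BERT output, and repeats the summed loss inline so it runs on its own; the variable names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

pooled = torch.randn(2, 1024)                  # stand-in for pooled BERT output (assumption)
labels = torch.tensor([[1, 4, 5], [0, 2, 3]])  # (batch, 3): class 0-5 per period

heads = nn.ModuleList([nn.Linear(1024, 6) for _ in range(3)])
criterion = nn.CrossEntropyLoss()

# Column t of `labels` is the target for period t; head outputs stay unnormalized
total = sum(criterion(head(pooled), labels[:, t]) for t, head in enumerate(heads))
total.backward()  # total is a scalar tensor, so backward() needs no arguments

# At inference, take the argmax of each head to recover a (batch, 3) label array
with torch.no_grad():
    preds = torch.stack([head(pooled).argmax(dim=1) for head in heads], dim=1)
```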