How do I compute bootstrapped cross entropy loss in PyTorch?

I have read some papers that use something called "Bootstrapped Cross Entropy Loss" to train their segmentation networks. The idea is to take only the hardest k% (say 15%) of the pixels into account, which improves learning when easy pixels dominate.

Currently, I am using the standard cross entropy:

loss = F.binary_cross_entropy(mask, gt)

How do I convert this to the bootstrapped version efficiently in PyTorch?

asked Sep 18 '25 by hkchengrex
1 Answer

Often we would also add a "warm-up" period to the loss so that the network can first learn to adapt to the easy regions and then transition to the harder ones.

This implementation starts with k=100 (i.e., using every pixel) for the first 20000 iterations, then linearly decays k to 15 over the next 50000 iterations.

import torch
import torch.nn as nn
import torch.nn.functional as F


class BootstrappedCE(nn.Module):
    def __init__(self, start_warm=20000, end_warm=70000, top_p=0.15):
        super().__init__()

        self.start_warm = start_warm
        self.end_warm = end_warm
        self.top_p = top_p

    def forward(self, input, target, it):
        # Warm-up: before start_warm, use the plain mean cross entropy over all pixels
        if it < self.start_warm:
            return F.cross_entropy(input, target), 1.0

        # Per-pixel loss, flattened so every pixel can be ranked
        raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
        num_pixels = raw_loss.numel()

        # Linearly decay the kept fraction from 1.0 to top_p between start_warm and end_warm
        if it > self.end_warm:
            this_p = self.top_p
        else:
            this_p = self.top_p + (1 - self.top_p) * ((self.end_warm - it) / (self.end_warm - self.start_warm))

        # Keep only the hardest this_p fraction of pixels and average over them
        loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
        return loss.mean(), this_p
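
You would then call it in the training loop, passing the current iteration number so the schedule takes effect. A minimal usage sketch (the shapes, class count, and iteration value below are arbitrary placeholders):

criterion = BootstrappedCE(start_warm=20000, end_warm=70000, top_p=0.15)

# Dummy data: logits of shape (N, C, H, W), integer class labels of shape (N, H, W)
logits = torch.randn(4, 2, 64, 64, requires_grad=True)
labels = torch.randint(0, 2, (4, 64, 64))

# it=30000 is mid-decay, so p_used will be somewhere between 1.0 and 0.15
loss, p_used = criterion(logits, labels, it=30000)
loss.backward()

If you want to stay with the binary formulation from your question, the same idea should work by replacing F.cross_entropy with F.binary_cross_entropy(mask, gt, reduction='none') inside the module, since bootstrapping only needs the per-pixel losses before reduction.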
answered Sep 21 '25 by hkchengrex