There is a well-known trick used with the U-Net architecture: applying custom weight maps to increase accuracy. The details are below.
By asking here and in multiple other places, I have learned about two approaches. I want to know which one is correct, or whether there is some other, more correct approach.
The first is to call torch.nn.functional.cross_entropy directly in the training loop:
loss = torch.nn.functional.cross_entropy(output, target, weight=w)
where w is the calculated custom weight.
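Note that the weight argument of torch.nn.functional.cross_entropy is documented as a per-class rescaling vector of length C, not a per-pixel map. A minimal runnable sketch of this first approach (tensor shapes are hypothetical, chosen only for illustration):

import torch
import torch.nn.functional as F

output = torch.randn(4, 2, 10, 10)         # logits: (batch, classes, H, W)
target = torch.randint(0, 2, (4, 10, 10))  # targets: (batch, H, W), dtype torch.long
w = torch.tensor([0.3, 0.7])               # per-CLASS weights, length = number of classes

loss = F.cross_entropy(output, target, weight=w)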
The second is to instantiate the loss function with reduction='none' outside the training loop:
criterion = torch.nn.CrossEntropyLoss(reduction='none')
and then, in the training loop, multiply by the custom weight:
gt # Ground truth, format torch.long
pd # Network output
W # per-element weighting based on the distance map from UNet
loss = criterion(pd, gt)
loss = W * loss # Apply the per-pixel weight map
loss = loss.flatten(start_dim=1).sum(dim=1) # Sum the loss per image -> (batch,)
loss = torch.mean(loss) # Average over the batch
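Put together as a self-contained sketch of this second approach (again, tensor shapes are assumed purely for illustration):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(reduction='none')

pd = torch.randn(4, 2, 10, 10)               # network output (logits): (batch, classes, H, W)
gt = torch.randint(0, 2, (4, 10, 10))        # ground truth: (batch, H, W), dtype torch.long
W = torch.rand(4, 10, 10) + 1.0              # per-pixel weight map from the U-Net distance map

loss = criterion(pd, gt)                     # per-pixel loss: (batch, H, W)
loss = W * loss                              # apply the weight map
loss = loss.flatten(start_dim=1).sum(dim=1)  # sum per image -> (batch,)
loss = loss.mean()                           # average over the batch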
Now I am confused: which one is right, is there another way, or are both right?
The weighting portion looks like plain weighted cross-entropy, which is performed like this for a given number of classes (2 in the example below):
import torch
import torch.nn as nn

weights = torch.FloatTensor([.3, .7])
loss_func = nn.CrossEntropyLoss(weight=weights)
EDIT:
Have you seen this implementation from Patrick Black?
import torch
import torch.nn.functional as F

# Set properties
batch_size = 10
out_channels = 2
W = 10
H = 10
# Initialize logits etc. with random
logits = torch.FloatTensor(batch_size, out_channels, H, W).normal_()
target = torch.LongTensor(batch_size, H, W).random_(0, out_channels)
weights = torch.FloatTensor(batch_size, 1, H, W).random_(1, 3)
# Calculate log probabilities
logp = F.log_softmax(logits, dim=1)
# Gather log probabilities with respect to target
logp = logp.gather(1, target.view(batch_size, 1, H, W))
# Multiply with weights
weighted_logp = (logp * weights).view(batch_size, -1)
# Rescale so that loss is in approx. same interval
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)
# Average over mini-batch
weighted_loss = -1. * weighted_loss.mean()
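One design difference worth noting: this version divides each image's summed weighted log-probabilities by the sum of its weights, which keeps the loss on roughly the same scale however large the weight map's values are, whereas the reduction='none' approach above simply multiplies and averages, so the two will generally differ by a scale factor.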