How do you compute the top k accuracy in semantic segmentation? In classification, we might compute the topk accuracy as:
correct = output.eq(gt.view(1, -1).expand_as(output))
You are looking for the torch.topk function, which computes the top k values along a given dimension.
The second output of torch.topk is the "arg top k": the k indices of the top values.
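As a quick illustration of the two outputs, here is a minimal sketch on a made-up 2x4 score tensor (the variable names and numbers are only for illustration):

```python
import torch

# Two rows of scores, four "classes" per row (toy values).
scores = torch.tensor([[0.1, 2.0, -1.0, 0.7],
                       [3.0, 0.2, 0.5, 1.5]])

# topk returns (values, indices); indices are the "arg top k" along dim=1.
values, indices = torch.topk(scores, k=2, dim=1)

print(indices)  # tensor([[1, 3], [0, 3]])
```

Both outputs are sorted from largest to smallest score by default, so indices[:, 0] is the same as the usual argmax.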
Here's how this can be used in the context of semantic segmentation:
Suppose you have the ground-truth label tensor y of shape b-h-w (dtype=torch.int64).
Your model predicts per-pixel class logits of shape b-c-h-w, where c is the number of classes (including "background"). These logits are the "raw" predictions, before the softmax function transforms them into class probabilities.
Since we only look at the top k, it does not matter whether the predictions are "raw" or "probabilities": softmax is monotonic, so it preserves the per-pixel ranking of the classes.
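If you want to convince yourself of this, a small check along these lines should do (the logits are made-up, well-separated values; the class dimension is dim=1 as above):

```python
import torch

# Toy logits: 2 "pixels" (rows), 4 classes along dim=1.
logits = torch.tensor([[2.0, -1.0, 0.5, 3.0],
                       [0.0, 1.0, -2.0, 0.25]])
probs = torch.softmax(logits, dim=1)

_, tk_logits = torch.topk(logits, 2, dim=1)
_, tk_probs = torch.topk(probs, 2, dim=1)

# The top-k class indices are the same before and after softmax.
same = torch.equal(tk_logits, tk_probs)
print(same)  # True
```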
# compute the top k predicted classes, per pixel:
_, tk = torch.topk(logits, k, dim=1)
# you now have k predictions per pixel; a pixel counts as correct if any of them matches the true label in y:
correct_pixels = torch.eq(y[:, None, ...], tk).any(dim=1)
# correct_pixels is a boolean tensor, so cast it to float before taking the mean
# to get the overall average top-k accuracy:
top_k_acc = correct_pixels.float().mean()
Note that this method does not take "ignore" pixels into account. Handling them requires only a slight modification to the above code:
# keep only the pixels whose label is not the ignore label:
valid = y != ignore_index
top_k_acc = correct_pixels[valid].float().mean()
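Putting the pieces together, a self-contained version might look like the sketch below. The function name top_k_pixel_accuracy, the toy tensors, and the ignore label 255 are assumptions made for illustration. The boolean mask is cast to float before averaging, because torch.mean is not defined for boolean tensors.

```python
import torch

def top_k_pixel_accuracy(logits, y, k, ignore_index=None):
    """Hypothetical helper: logits is (b, c, h, w) float, y is (b, h, w) int64."""
    _, tk = torch.topk(logits, k, dim=1)                 # (b, k, h, w) class indices
    correct = torch.eq(y[:, None, ...], tk).any(dim=1)   # (b, h, w) bool
    if ignore_index is not None:
        correct = correct[y != ignore_index]             # drop ignored pixels
    return correct.float().mean()

# Toy example: b=1, c=3, h=1, w=4.
logits = torch.tensor([[[[2.0, 0.1, 0.3, 1.0]],    # class 0 scores
                        [[1.0, 0.5, 0.2, 3.0]],    # class 1 scores
                        [[0.0, 0.9, 0.1, 2.0]]]])  # class 2 scores
y = torch.tensor([[[0, 0, 1, 255]]])               # last pixel is "ignore"

top1 = top_k_pixel_accuracy(logits, y, k=1, ignore_index=255)  # 1 of 3 valid pixels
top2 = top_k_pixel_accuracy(logits, y, k=2, ignore_index=255)  # 2 of 3 valid pixels
```

One caveat: if every pixel carries the ignore label, the mean is taken over an empty tensor and the result is NaN, so you may want to guard against that case.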