How do you compute top-k accuracy in semantic segmentation? In classification, we might compute the top-k accuracy as:
correct = output.eq(gt.view(1, -1).expand_as(output))
You are looking for the torch.topk function, which computes the top k values along a dimension.
The second output of torch.topk is the "arg top k": the k indices of the top values.
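As a small illustration of the two outputs, here is a sketch on a made-up 2-by-4 score tensor (the values in `scores` are arbitrary and only serve the example):

```python
import torch

# Hypothetical scores: 2 rows, 4 candidates per row.
scores = torch.tensor([[0.1, 0.7, 0.2, 0.5],
                       [0.9, 0.3, 0.8, 0.1]])

# values:  the 2 largest entries of each row, in descending order
# indices: the positions of those entries along dim=1 (the "arg top k")
values, indices = torch.topk(scores, k=2, dim=1)

print(indices)  # tensor([[1, 3], [0, 2]])
```

For top-k accuracy only `indices` is needed, which is why the values are usually discarded.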
Here's how this can be used in the context of semantic segmentation:
Suppose you have the ground-truth label tensor y of shape b-h-w (dtype=torch.int64).
Your model predicts per-pixel class logits of shape b-c-h-w, where c is the number of classes (including "background"). These logits are the "raw" predictions before the softmax function transforms them into class probabilities.
Since we are only looking at the top k, it does not matter if the predictions are "raw" or "probabilities": softmax is monotonic, so it preserves the per-pixel ranking of the classes.
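A quick way to convince yourself of this is to compare the top-k indices before and after softmax. The sketch below uses random logits with an arbitrary shape, and assumes no two classes have numerically tied scores at any pixel:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(2, 5, 4, 4)        # b-c-h-w, random stand-in for model output
probs = torch.softmax(logits, dim=1)    # per-pixel class probabilities

_, tk_logits = torch.topk(logits, 3, dim=1)
_, tk_probs = torch.topk(probs, 3, dim=1)

# True when both inputs yield the same top-3 classes in the same order
same = torch.equal(tk_logits, tk_probs)
```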
import torch
# compute the top k predicted classes, per pixel (tk has shape b-k-h-w):
_, tk = torch.topk(logits, k, dim=1)
# you now have k predictions per pixel, and you want at least one of them to match the true labels y:
correct_pixels = torch.eq(y[:, None, ...], tk).any(dim=1)
# correct_pixels is a bool tensor, so cast it to float before taking the mean to get the overall top-k accuracy:
top_k_acc = correct_pixels.float().mean()
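Note that this method does not take "ignore" pixels into account. They can be excluded with a slight modification to the code above:
# ignore_index is the label value that marks pixels to skip
valid = y != ignore_index
# keep only the valid pixels, then cast the bool tensor to float before averaging:
top_k_acc = correct_pixels[valid].float().mean()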
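Putting the pieces together, one possible way to wrap this up is a small helper. The function name `top_k_accuracy`, its signature, and the ignore value 255 in the toy data are assumptions made for this sketch, not part of any library:

```python
import torch

def top_k_accuracy(logits, y, k, ignore_index=None):
    """Hypothetical helper. logits: b-c-h-w float tensor; y: b-h-w int64 tensor."""
    _, tk = torch.topk(logits, k, dim=1)                 # b-k-h-w class indices
    correct = torch.eq(y[:, None, ...], tk).any(dim=1)   # b-h-w bool tensor
    if ignore_index is not None:
        correct = correct[y != ignore_index]             # keep valid pixels only
    return correct.float().mean()

# Toy data: one 2x2 image, 3 classes, class 0 scored highest at every pixel.
logits = torch.zeros(1, 3, 2, 2)
logits[0, 0] = 3.0
logits[0, 1] = 2.0
logits[0, 2] = 1.0
y = torch.tensor([[[0, 1], [2, 255]]])   # 255 marks an "ignore" pixel (assumed value)

top1 = top_k_accuracy(logits, y, k=1, ignore_index=255)  # 1 of 3 valid pixels correct
top2 = top_k_accuracy(logits, y, k=2, ignore_index=255)  # 2 of 3 valid pixels correct
```

Note that if every pixel is ignored, the mean is taken over an empty tensor and the result is NaN, so a caller may want to handle that case separately.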