How to get top k accuracy in semantic segmentation using PyTorch?

How do you compute the top-k accuracy in semantic segmentation? In classification, we might compute the top-k accuracy as:

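_, pred = output.topk(k, dim=1)  # (batch, k) indices of the k highest-scoring classes
correct = pred.t().eq(gt.view(1, -1).expand_as(pred.t()))  # (k, batch) boolean matches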
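For reference, a complete version of the classification case could look like the sketch below. The function name `topk_accuracy_cls` and the toy tensors are hypothetical; `output` is assumed to hold per-sample class scores of shape (batch, num_classes) and `gt` integer labels of shape (batch,).

```python
import torch


def topk_accuracy_cls(output: torch.Tensor, gt: torch.Tensor, k: int) -> float:
    """Top-k accuracy for plain classification (hypothetical helper).

    output: (batch, num_classes) logits or scores
    gt:     (batch,) int64 class labels
    """
    # indices of the k highest-scoring classes, shape (batch, k)
    _, pred = output.topk(k, dim=1)
    # transpose to (k, batch) so each row can be compared against the labels
    pred = pred.t()
    correct = pred.eq(gt.view(1, -1).expand_as(pred))
    # a sample counts as correct if any of its k predictions equals its label
    return correct.any(dim=0).float().mean().item()


# Toy data: 4 samples, 3 classes.
output = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2],
                       [0.2, 0.3, 0.5],
                       [0.6, 0.1, 0.3]])
gt = torch.tensor([1, 1, 0, 0])

print(topk_accuracy_cls(output, gt, k=1))  # 0.5
print(topk_accuracy_cls(output, gt, k=2))  # 0.75
```

Since the k indices returned by `topk` are distinct, at most one of them can match a sample's label, so `any` over the k rows gives the same result as summing the matches.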
asked Dec 25 '19 at 03:12 by JJTT


1 Answer

You are looking for the torch.topk function, which computes the top k values along a given dimension.
The second output of torch.topk is the "arg top k": the indices of the k top values.
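A small illustration of the two outputs of `torch.topk` on a toy 2-by-4 score tensor (the values are made up for the example). By default the results are sorted from largest to smallest.

```python
import torch

scores = torch.tensor([[0.1, 0.7, 0.2, 0.9],
                       [0.8, 0.3, 0.6, 0.05]])

# values: the k largest entries per row; indices: their positions ("arg top k")
values, indices = torch.topk(scores, k=2, dim=1)

print(values.tolist())   # approximately [[0.9, 0.7], [0.8, 0.6]]
print(indices.tolist())  # [[3, 1], [0, 2]]
```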

Here's how this can be used in the context of semantic segmentation:
Suppose you have the ground-truth label tensor y of shape b-h-w (dtype=torch.int64).
Your model predicts per-pixel class logits of shape b-c-h-w, where c is the number of classes (including "background"). These logits are the "raw" predictions before the softmax function transforms them into class probabilities. Since softmax is monotonic, it does not change the ordering of the classes at any pixel, so for the top k it does not matter whether you use the "raw" logits or the "probabilities".
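To check the claim that softmax does not change the top-k classes, here is a sketch on hand-picked toy logits of shape (b=1, c=4, h=1, w=2). The specific numbers are arbitrary and chosen without near-ties so floating-point rounding cannot affect the ordering.

```python
import torch

# Toy logits, indexed as logits[batch, class, row, col].
# Pixel 0 class scores: [2.0, -1.0, 0.5, 3.0]; pixel 1: [0.0, 1.5, -2.0, 1.0].
logits = torch.tensor([[[[2.0, 0.0]],
                        [[-1.0, 1.5]],
                        [[0.5, -2.0]],
                        [[3.0, 1.0]]]])
probs = torch.softmax(logits, dim=1)  # per-pixel class probabilities

_, tk_logits = torch.topk(logits, 2, dim=1)  # shape (1, 2, 1, 2)
_, tk_probs = torch.topk(probs, 2, dim=1)

print(tk_logits[0, :, 0, 0].tolist())    # [3, 0]
print(tk_logits[0, :, 0, 1].tolist())    # [1, 3]
print(torch.equal(tk_logits, tk_probs))  # True
```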

import torch

# compute the top k predicted classes, per pixel:
_, tk = torch.topk(logits, k, dim=1)  # tk has shape b-k-h-w
# you now have k predictions per pixel, and you want one of them to match the true label y:
correct_pixels = torch.eq(y[:, None, ...], tk).any(dim=1)  # b-h-w boolean mask
# take the mean of correct_pixels to get the overall average top-k accuracy
# (cast to float first, since .mean() is not defined for boolean tensors):
top_k_acc = correct_pixels.float().mean()
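Wrapped into a self-contained function and run on a tiny hand-made batch, the approach above could look like this. The name `topk_pixel_accuracy` and the toy data (1 image, 3 classes, a 1-by-4 "image") are hypothetical.

```python
import torch


def topk_pixel_accuracy(logits: torch.Tensor, y: torch.Tensor, k: int) -> float:
    """Fraction of pixels whose true label is among the k highest-scoring classes.

    logits: (b, c, h, w) per-pixel class scores
    y:      (b, h, w) int64 ground-truth labels
    """
    _, tk = torch.topk(logits, k, dim=1)                # (b, k, h, w) class indices
    correct = torch.eq(y[:, None, ...], tk).any(dim=1)  # (b, h, w) boolean
    return correct.float().mean().item()


# Toy batch indexed as logits[batch, class, row, col]: b=1, c=3, h=1, w=4.
logits = torch.tensor([[[[2.0, 0.0, 1.0, 2.0]],
                        [[1.0, 2.0, 0.0, 0.0]],
                        [[0.0, 1.0, 2.0, 1.0]]]])
y = torch.tensor([[[0, 2, 1, 1]]])

for k in (1, 2, 3):
    print(k, topk_pixel_accuracy(logits, y, k))
# 1 0.25
# 2 0.5
# 3 1.0
```

With k equal to the number of classes every pixel is trivially correct, which is a quick sanity check for the broadcasting in `torch.eq`.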

Note that this method does not account for "ignore" pixels (pixels labelled with an ignore_index). They can be excluded with a slight modification to the code above:

valid = y != ignore_index  # boolean mask of pixels that carry a real label
top_k_acc = correct_pixels[valid].float().mean()
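A self-contained sketch of the masked variant follows. The helper name `topk_pixel_accuracy_masked` is hypothetical, and `ignore_index=255` is only an assumed default (a common convention); use whatever value your dataset marks as "ignore". Unlike plain boolean indexing followed by `.mean()`, which yields NaN when every pixel is ignored, this version divides by a clamped count and returns 0.0 in that case.

```python
import torch


def topk_pixel_accuracy_masked(logits: torch.Tensor, y: torch.Tensor, k: int,
                               ignore_index: int = 255) -> float:
    """Top-k pixel accuracy over pixels whose label is not ignore_index.

    ignore_index=255 is an assumed default, not something fixed by PyTorch.
    """
    _, tk = torch.topk(logits, k, dim=1)                # (b, k, h, w)
    correct = torch.eq(y[:, None, ...], tk).any(dim=1)  # (b, h, w) boolean
    valid = y != ignore_index                           # (b, h, w) boolean
    # count correct valid pixels; clamp avoids division by zero if all are ignored
    n_correct = (correct & valid).sum().float()
    n_valid = valid.sum().clamp(min=1).float()
    return (n_correct / n_valid).item()


# Toy batch: b=1, c=3, h=1, w=4; the last two pixels are marked as ignored.
logits = torch.tensor([[[[2.0, 0.0, 1.0, 2.0]],
                        [[1.0, 2.0, 0.0, 0.0]],
                        [[0.0, 1.0, 2.0, 1.0]]]])
y = torch.tensor([[[0, 2, 255, 255]]])

print(topk_pixel_accuracy_masked(logits, y, k=1))  # 0.5
print(topk_pixel_accuracy_masked(logits, y, k=2))  # 1.0
```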
answered Oct 24 '22 at 12:10 by Shai