
Count Unique elements in pytorch Tensor

Tags:

pytorch

Suppose I have the following tensor: y = torch.randint(0, 3, (10,)). How would you go about counting the 0s, 1s, and 2s in it?

The only way I can think of is using collections.Counter(y), but I was wondering whether there is a more "pytorch" way of doing this. One use case, for example, would be building the confusion matrix for predictions.

sachinruk asked Apr 01 '26 18:04


1 Answer

You can use torch.unique with the return_counts option:

>>> import torch
>>> x = torch.randint(0, 3, (10,))
>>> x
tensor([1, 1, 0, 2, 1, 0, 1, 1, 2, 1])

>>> x.unique(return_counts=True)
(tensor([0, 1, 2]), tensor([2, 6, 2]))
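
For the confusion-matrix use case mentioned in the question, here is a minimal sketch rather than a definitive implementation. It assumes integer class labels in the range 0..num_classes-1; the names num_classes, target, and pred are hypothetical and chosen only for illustration. It uses torch.bincount, which, unlike torch.unique, also reports a zero for classes that never occur:

```python
import torch

# Fixed tensor with the same values as the sample output shown above.
x = torch.tensor([1, 1, 0, 2, 1, 0, 1, 1, 2, 1])

# torch.unique only lists the values that actually occur in x.
values, counts = torch.unique(x, return_counts=True)

# torch.bincount returns one count per index 0..minlength-1,
# so absent classes show up as 0 instead of being skipped.
num_classes = 3  # assumed number of classes
per_class = torch.bincount(x, minlength=num_classes)

# Confusion matrix sketch with hypothetical labels and predictions.
target = torch.tensor([0, 1, 2, 2, 1, 0])
pred = torch.tensor([0, 2, 2, 2, 1, 1])

# Encode each (target, pred) pair as one flat index, count the
# occurrences, and reshape so rows are targets and columns are predictions.
flat = target * num_classes + pred
confusion = torch.bincount(flat, minlength=num_classes ** 2)
confusion = confusion.reshape(num_classes, num_classes)

print(per_class)
print(confusion)
```

Passing minlength keeps the output shape fixed even when a class is missing from a batch, which is usually what you want when accumulating counts across batches.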
Ivan answered Apr 03 '26 16:04