groupby aggregate mean in pytorch

Tags:

pytorch

I have a 2D tensor:

import torch

samples = torch.Tensor([
    [0.1, 0.1],    #-> group / class 1
    [0.2, 0.2],    #-> group / class 2
    [0.4, 0.4],    #-> group / class 2
    [0.0, 0.0]     #-> group / class 0
])

and a label for each sample corresponding to a class:

labels = torch.LongTensor([1, 2, 2, 0])

So len(samples) == len(labels). Now I want to calculate the mean for each class/label. Because there are 3 classes (0, 1 and 2), the final tensor should have shape [n_classes, samples.shape[1]]. So the expected solution should be:

result == torch.Tensor([
    [0.1, 0.1],
    [0.3, 0.3], # -> mean of [0.2, 0.2] and [0.4, 0.4]
    [0.0, 0.0]
])

Question: How can this be done in pure PyTorch (i.e. no NumPy, so that I can use autograd) and ideally without for loops?

asked Jan 26 '23 07:01 by elyase


1 Answer

All you need to do is form an m x n matrix (m = number of classes, n = number of samples) that selects the samples belonging to each class and weights each one by 1 / (number of samples in that class). A matrix multiplication between this matrix and the samples matrix then gives the per-class means.
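Written out, this is a sketch of what the matrix encodes, using notation introduced here for illustration: x_i is row i of samples, y_i is its label, and n_c is the number of samples with label c.

```latex
M_{ci} = \frac{\mathbb{1}[y_i = c]}{n_c},
\qquad
(MX)_c = \sum_{i=0}^{n-1} M_{ci}\, x_i = \frac{1}{n_c} \sum_{i \,:\, y_i = c} x_i
```

For the labels [1, 2, 2, 0], class 2 has n_2 = 2 samples (indices 1 and 2, counting from 0), so M_{2,1} = M_{2,2} = 1/2 and every other entry in that row is 0.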

Given your labels, your matrix should be as follows (each row is a class, each column is a sample, and each entry is that sample's weight in the class mean):

[[0.0000, 0.0000, 0.0000, 1.0000],
 [1.0000, 0.0000, 0.0000, 0.0000],
 [0.0000, 0.5000, 0.5000, 0.0000]]

Which you can form as follows:

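import torch

M = torch.zeros(int(labels.max()) + 1, len(samples))  # one row per class, one column per sample
M[labels, torch.arange(len(samples))] = 1             # M[c, i] = 1 if sample i has label c
M = torch.nn.functional.normalize(M, p=1, dim=1)      # divide each row by its L1 norm, i.e. the class count
torch.mm(M, samples)                                  # (classes x samples) @ (samples x features)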

Output:

tensor([[0.0000, 0.0000],
        [0.1000, 0.1000],
        [0.3000, 0.3000]])

Note that the output rows are in class order (0, 1, 2), which differs from the row order in the expected result given in the question.
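Since the question asks for something autograd-friendly, here is a minimal sketch that wraps the same selection-matrix approach in a reusable function and checks that gradients reach samples. The name groupby_mean is hypothetical (not a PyTorch API), and the sketch assumes labels is a 1D integer tensor with values in 0..n_classes-1.

```python
import torch
import torch.nn.functional as F


def groupby_mean(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-class mean of the rows of `samples`.

    Row c of the result is the mean of the rows whose label is c.
    A class with no samples gets a row of zeros (normalize clamps the
    divisor with a small eps, so there is no division by zero).
    """
    n_classes = int(labels.max()) + 1
    n_samples = samples.shape[0]
    # Selection matrix: M[c, i] = 1 if sample i has label c
    M = torch.zeros(n_classes, n_samples, dtype=samples.dtype, device=samples.device)
    M[labels, torch.arange(n_samples, device=samples.device)] = 1
    # Divide each row by its L1 norm (the class count), so weights become 1/count
    M = F.normalize(M, p=1, dim=1)
    # (n_classes x n_samples) @ (n_samples x n_features) -> (n_classes x n_features)
    return M @ samples


samples = torch.tensor([[0.1, 0.1],
                        [0.2, 0.2],
                        [0.4, 0.4],
                        [0.0, 0.0]], requires_grad=True)
labels = torch.tensor([1, 2, 2, 0])

means = groupby_mean(samples, labels)
print(means)

# Only differentiable tensor ops touch `samples`, so backward() works;
# each sample's gradient equals its weight 1/count in its class row
means.sum().backward()
print(samples.grad)
```

Because M is built from labels only, it carries no gradient itself; the gradient with respect to each sample is simply its weight in M.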

Why does M[labels, torch.arange(len(samples))] = 1 work?

This uses advanced (integer-tensor) indexing: the two index tensors, labels and torch.arange(len(samples)), are broadcast together (here they already have the same shape, (n,)) and paired element by element. Essentially, we are generating a 2D index for every element in labels: the first component specifies which of the m classes the sample belongs to, and the second simply specifies its index position (from 0 to n - 1). Another way would be to explicitly generate all the 2D indices:

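# Explicit equivalent: one (class, sample) index pair per sample
twoD_indices = []
for count, label in enumerate(labels):
    twoD_indices.append((int(label), count))
# for labels [1, 2, 2, 0]: [(1, 0), (2, 1), (2, 2), (0, 3)]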
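To check that the two ways of building the index agree, this sketch fills one matrix with the vectorised indexing and another with a loop over the explicit index pairs, then compares them. The variable names M_fancy and M_loop are illustrative only.

```python
import torch

labels = torch.tensor([1, 2, 2, 0])
n_classes, n_samples = int(labels.max()) + 1, len(labels)

# Vectorised: index tensors are paired element by element,
# so position k sets M_fancy[labels[k], k] = 1
M_fancy = torch.zeros(n_classes, n_samples)
M_fancy[labels, torch.arange(n_samples)] = 1

# Explicit loop over the same (class, sample) index pairs
twoD_indices = [(int(label), count) for count, label in enumerate(labels)]
M_loop = torch.zeros(n_classes, n_samples)
for row, col in twoD_indices:
    M_loop[row, col] = 1

print(twoD_indices)                # [(1, 0), (2, 1), (2, 2), (0, 3)]
print(torch.equal(M_fancy, M_loop))  # True
```

The vectorised form performs a single indexed write instead of n separate Python-level assignments, which is why it is preferred for large inputs.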

answered Feb 01 '23 03:02 by yhenon