How to sort a 2D tensor by its first column in PyTorch?

I have a 2D tensor and I would like to sort its rows by the first column, as in this example:

import torch

a = torch.tensor(
   [[5, 5],
    [5, 3],
    [3, 5],
    [6, 4],
    [3, 7]], dtype=torch.float32)

The result I expect after sorting is:

a = torch.tensor(
   [[3, 5],
    [3, 7],
    [5, 3],
    [5, 5],
    [6, 4]], dtype=torch.float32)

Is it possible to do this in PyTorch? I know it is possible in NumPy, but I want to do it on the GPU using torch.

asked May 07 '20 18:05 by Fernando Noronha


2 Answers

Sort the first column and use the returned indices to reorder the rows of the whole tensor:

a[a[:, 0].sort()[1]]

Output:

tensor([[3., 5.],
        [3., 7.],
        [5., 5.],
        [5., 3.],
        [6., 4.]])
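Because the permutation comes from a tensor operation, the whole reorder stays on whichever device holds the tensor, which covers the GPU requirement in the question. A minimal sketch, assuming a PyTorch release whose `torch.sort` accepts the `stable` keyword; the helper name `sort_rows_by_first_column` is hypothetical:

```python
import torch

# Use the GPU when one is available; the same code runs on the CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

a = torch.tensor(
    [[5, 5],
     [5, 3],
     [3, 5],
     [6, 4],
     [3, 7]], dtype=torch.float32, device=device)

def sort_rows_by_first_column(t):
    # Hypothetical helper. torch.sort returns (values, indices); the indices
    # are the permutation that orders column 0, and indexing with them
    # reorders whole rows. stable=True keeps tied rows in their original order.
    order = torch.sort(t[:, 0], stable=True).indices
    return t[order]

result = sort_rows_by_first_column(a)
print(result.cpu())
```

Without `stable=True` the relative order of rows that share the same first value is not guaranteed, so the exact output above may differ between runs or devices.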

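If you also need ties in the first column broken by the second column (a lexicographic sort, as in the question), sort by the second column first and then stably by the first:

b = a[a[:, 1].sort()[1]]            # order by the secondary key (column 1)
b[b[:, 0].sort(stable=True)[1]]     # a stable sort on column 0 keeps the column-1 order within ties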

Output:

tensor([[3., 5.],
        [3., 7.],
        [5., 3.],
        [5., 5.],
        [6., 4.]])
answered Sep 26 '22 00:09 by cosmic_inquiry
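The two-pass trick generalizes to any number of columns: apply one stable sort per column, from the last (least significant) column to the first, carrying the permutation along. A minimal sketch; `lexsort_rows` is a hypothetical helper, not a PyTorch function, and it assumes `torch.sort` accepts `stable=True`:

```python
import torch

def lexsort_rows(t):
    # Hypothetical helper: sort the rows of a 2D tensor lexicographically,
    # with column 0 as the primary key, column 1 as the next, and so on.
    order = torch.arange(t.shape[0], device=t.device)
    # Least significant column first; every pass must be stable so the
    # ordering produced by earlier passes survives among tied keys.
    for col in reversed(range(t.shape[1])):
        idx = torch.sort(t[order, col], stable=True).indices
        order = order[idx]
    return t[order]

a = torch.tensor(
    [[5, 5],
     [5, 3],
     [3, 5],
     [6, 4],
     [3, 7]], dtype=torch.float32)

b = lexsort_rows(a)
print(b)
```

Each pass is one sort of length n, so a tensor with k columns costs k sorts, all executed on the tensor's own device.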


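Alternatively, sort the rows with Python's built-in sorted, keyed on the first column, and stack them back into one tensor (this loops over the rows in Python, so it is slower than the indexing approach):

torch.stack(sorted(a, key=lambda row: row[0].item()))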

Output will be:

tensor([[3., 5.],
        [3., 7.],
        [5., 5.],
        [5., 3.],
        [6., 4.]])
answered Sep 24 '22 00:09 by 汉林张
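The `sorted`-based approach can also break ties with the second column if the rows are compared as plain Python lists, which compare element by element. A minimal sketch; note that it copies the data into Python objects, so unlike the indexing approach the sort itself does not run on the GPU:

```python
import torch

a = torch.tensor(
    [[5, 5],
     [5, 3],
     [3, 5],
     [6, 4],
     [3, 7]], dtype=torch.float32)

# tolist() copies the rows into plain Python lists; lists compare element by
# element, so sorted() orders by column 0 and breaks ties with column 1.
rows = sorted(a.tolist())
b = torch.tensor(rows, dtype=a.dtype, device=a.device)
print(b)
```

This is convenient for small tensors, but for large ones the per-row Python work dominates and the tensor-indexing approach is preferable.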