 

PyTorch - better way to get back original tensor order after torch.sort

Tags: python, pytorch

I want to get back the original tensor order after a torch.sort operation and some further modifications to the sorted tensor, so that the result is no longer sorted. It is easier to explain this with an example:

x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
# ordered is [20., 30., 40.]
# indices is [2, 0, 1]
ordered = torch.tanh(ordered) # it doesn't matter which operation it is
final = original_order(ordered, indices) 
# final must be equal to torch.tanh(x)

I have implemented the function in this way:

def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    for i in range(ordered.size(0)):
        z[indices[i]] = ordered[i]
    return z
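
A quick check of this loop version, just re-running the example values from the top:

x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
ordered = torch.tanh(ordered)
final = original_order(ordered, indices)
print(torch.allclose(final, torch.tanh(x)))  # True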

Is there a better way to do this? In particular, is it possible to avoid the loop and compute the operation more efficiently?

In my case I have a tensor of size torch.Size([B, N]) and I sort each of the B rows separately with a single call to torch.sort. So, I have to call original_order B times with another loop, as sketched below.
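
To make this concrete, a minimal sketch of that outer loop, assuming x has shape [B, N] and using the loop-based original_order above:

x = torch.randn(4, 5)                    # B = 4, N = 5
ordered, indices = torch.sort(x, dim=1)  # one call sorts every row
ordered = torch.tanh(ordered)
# the part I would like to avoid: one original_order call per row
restored = torch.stack([original_order(ordered[b], indices[b])
                        for b in range(x.size(0))])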

Any more PyTorch-ic ideas?

EDIT 1 - Get rid of inner loop

I solved part of the problem by simply indexing z with indices in this way:

def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    z[indices] = ordered
    return z

Now, I just have to understand how to avoid the outer loop over the B dimension.

EDIT 2 - Get rid of outer loop

def original_order(ordered, indices, batch_size):
    # produce a vector that shifts the indices by the row length
    # times the batch position
    add = torch.linspace(0, batch_size-1, batch_size) * indices.size(1)

    indices = indices + add.long().view(-1, 1)

    # reduce the tensors to a single dimension.
    # Now the indices take the new row length into account
    long_ordered = ordered.view(-1)
    long_indices = indices.view(-1)

    # we are back in the previous one-dimensional case
    z = torch.zeros_like(long_ordered).float()
    z[long_indices] = long_ordered

    # reshape to get back to the correct dimension
    return z.view(batch_size, -1)
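
A quick sanity check of this batched version, with a small made-up 3×4 example:

x = torch.randn(3, 4)
ordered, indices = torch.sort(x, dim=1)
ordered = torch.tanh(ordered)
restored = original_order(ordered, indices, batch_size=x.size(0))
print(torch.allclose(restored, torch.tanh(x)))  # True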

asked Sep 01 '18 by Bobby

1 Answer

def original_order(ordered, indices):
    return ordered.gather(1, indices.argsort(1))

Example

original = torch.tensor([
    [20, 22, 24, 21],
    [12, 14, 10, 11],
    [34, 31, 30, 32]])
sorted, index = original.sort()
unsorted = sorted.gather(1, index.argsort(1))
assert(torch.all(original == unsorted))

Why it works

For simplicity, imagine t = [30, 10, 20], omitting tensor notation.

t.sort() gives us the sorted tensor s = [10, 20, 30], as well as the sorting index i = [1, 2, 0] for free. i is in fact the output of t.argsort().

i tells us how to go from t to s. "To sort t into s, take element 1, then 2, then 0, from t". Argsorting i gives us another sorting index j = [2, 0, 1], which tells us how to go from i to the canonical sequence of natural numbers [0, 1, 2], in effect reversing the sort. Another way to look at it is that j tells us how to go from s to t. "To sort s into t, take element 2, then 0, then 1, from s". Argsorting a sorting index gives us its "inverse index", going the other way.
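
The same round trip in code, using plain indexing on the toy values:

t = torch.tensor([30, 10, 20])
s, i = t.sort()               # s = [10, 20, 30], i = [1, 2, 0]
j = i.argsort()               # j = [2, 0, 1], the inverse index
print(torch.equal(t[i], s))   # True: i sorts t into s
print(torch.equal(s[j], t))   # True: j takes s back to t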

Now that we have the inverse index, we dump that into torch.gather() with the correct dim, and that unsorts the tensor.

Sources

torch.gather, torch.argsort

I couldn't find this exact solution when researching this problem, so I think this is an original answer.

answered Nov 02 '22 by qmk