I want to get back the original tensor order after a torch.sort
operation and some other modifications to the sorted tensor, so that the tensor is no longer sorted. It is easier to explain this with an example:
x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
# ordered is [20., 30., 40.]
# indices is [2, 0, 1]
ordered = torch.tanh(ordered) # it doesn't matter what the operation is
final = original_order(ordered, indices)
# final must be equal to torch.tanh(x)
I have implemented the function in this way:
def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    for i in range(ordered.size(0)):
        z[indices[i]] = ordered[i]
    return z
Is there a better way to do this? In particular, is it possible to avoid the loop and compute the operation more efficiently?
In my case I have a tensor of size torch.Size([B, N]) and I sort each of the B rows separately with a single call of torch.sort. So, I have to call original_order B times with another loop, roughly as sketched below.
Any more pytorch-ic ideas?
EDIT 1 - Get rid of inner loop
I solved part of the problem by simply indexing z with indices in this way:
def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    z[indices] = ordered
    return z
Now, I just have to understand how to avoid the outer loop on the B dimension.
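As a quick check, this sketch reruns the example from the top through the new version (values are just illustrative):

import torch

x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
ordered = torch.tanh(ordered)

# z[indices] = ordered scatters each sorted element back to its original slot
final = original_order(ordered, indices)
assert torch.allclose(final, torch.tanh(x))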
EDIT 2 - Get rid of outer loop
def original_order(ordered, indices, batch_size):
    # produce a vector to shift the indices by the length of a row
    # times the batch position
    add = torch.linspace(0, batch_size-1, batch_size) * indices.size(1)
    indices = indices + add.long().view(-1,1)
    # reduce the tensors to a single dimension.
    # Now the indices take the new length into account
    long_ordered = ordered.view(-1)
    long_indices = indices.view(-1)
    # we are back in the previous case with a one-dimensional vector
    z = torch.zeros_like(long_ordered).float()
    z[long_indices] = long_ordered
    # reshape to get back to the correct dimension
    return z.view(batch_size, -1)
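A quick sanity check of this batched version, on a small example with made-up values:

import torch

x = torch.tensor([[30., 40., 20.],
                  [ 5.,  1.,  3.]])
ordered, indices = torch.sort(x, dim=1)
ordered = torch.tanh(ordered)

final = original_order(ordered, indices, batch_size=x.size(0))
assert torch.allclose(final, torch.tanh(x))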
You can unsort the whole batch in a single call with gather and a second argsort:

def original_order(ordered, indices):
    return ordered.gather(1, indices.argsort(1))
original = torch.tensor([
    [20, 22, 24, 21],
    [12, 14, 10, 11],
    [34, 31, 30, 32]])

sorted, index = original.sort()
unsorted = sorted.gather(1, index.argsort(1))

assert(torch.all(original == unsorted))
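The same one-liner also covers the original use case of applying an operation in sorted order and then restoring the order; a sketch with made-up values:

import torch

x = torch.tensor([[30., 40., 20.],
                  [ 5.,  1.,  3.]])
ordered, indices = x.sort(dim=1)
ordered = torch.tanh(ordered)            # any elementwise op applied in sorted order

final = ordered.gather(1, indices.argsort(1))
assert torch.allclose(final, torch.tanh(x))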
For simplicity, imagine t = [30, 10, 20], omitting tensor notation. t.sort() gives us the sorted tensor s = [10, 20, 30], as well as the sorting index i = [1, 2, 0] for free. i is in fact the output of t.argsort().
i tells us how to go from t to s: "To sort t into s, take element 1, then 2, then 0, from t". Argsorting i gives us another sorting index j = [2, 0, 1], which tells us how to go from i to the canonical sequence of natural numbers [0, 1, 2], in effect reversing the sort. Another way to look at it is that j tells us how to go from s to t: "To sort s into t, take element 2, then 0, then 1, from s". Argsorting a sorting index gives us its "inverse index", going the other way.
Now that we have the inverse index, we dump that into torch.gather() with the correct dim, and that unsorts the tensor.
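A small numeric check of the whole idea, using the same toy values:

import torch

t = torch.tensor([30, 10, 20])
s, i = t.sort()                # s = [10, 20, 30], i = [1, 2, 0]
j = i.argsort()                # j = [2, 0, 1], the inverse index

print(i[j])                    # tensor([0, 1, 2])    -> argsorting i recovers the identity order
print(s[j])                    # tensor([30, 10, 20]) -> j takes s back to t
print(torch.gather(s, 0, j))   # same result via gather along dim 0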
Docs: torch.gather, torch.argsort
I couldn't find this exact solution when researching this problem, so I think this is an original answer.