 

Pytorch differences between two tensors

I have two tensors like this:

1st tensor
[[0,0],[0,1],[0,2],[1,3],[1,4],[2,1],[2,4]]

2nd tensor
[[0,1],[0,2],[1,4],[2,4]]

I want the result tensor to be like this:

[[0,0],[1,3],[2,1]] # differences between 1st tensor and 2nd tensor

I have tried using set, list, torch.where, etc., and couldn't find a good way to achieve this. Is there a way to get the rows of one tensor that are not in another tensor of a different size? (It needs to be efficient.)

asked Mar 03 '23 at 07:03 by user19283043



2 Answers

You can perform a pairwise comparison to see which rows of the first tensor are present in the second tensor.

import torch

a = torch.as_tensor([[0,0],[0,1],[0,2],[1,3],[1,4],[2,1],[2,4]])
b = torch.as_tensor([[0,1],[0,2],[1,4],[2,4]])

# Expand a to (7, 1, 2) so it broadcasts against every row of b
a_exp = a.unsqueeze(1)

# c: (7, 4, 2)
c = a_exp == b
# Since we want all components of a row to be equal, we reduce over the last dim
# c: (7, 4)
c = c.all(-1)
print(c)
# Out: each row i compares the ith row of a against all rows of b.
# Therefore, if a whole row is False, that row of a is not present in b.
# tensor([[False, False, False, False],
#         [ True, False, False, False],
#         [False,  True, False, False],
#         [False, False, False, False],
#         [False, False,  True, False],
#         [False, False, False, False],
#         [False, False, False,  True]])

# True where the row of a has no match anywhere in b: shape (7,)
non_repeat_mask = ~c.any(-1)

# Apply the mask to a
print(a[non_repeat_mask])
# tensor([[0, 0],
#         [1, 3],
#         [2, 1]])

If you're feeling fancy, you can do it in one line :)

a[~a.unsqueeze(1).eq(b).all(-1).any(-1)]
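Since the question asks for efficiency, note that the broadcast comparison materializes an N×M×D boolean tensor (7×4×2 here), which can get large when both tensors have many rows. A sketch of an alternative, assuming PyTorch 1.10 or newer for `torch.isin`: map every row to an integer id with `torch.unique(..., dim=0, return_inverse=True)`, then compare the ids as 1-D tensors. `rows_not_in` is a hypothetical helper name, not a PyTorch function.

```python
import torch


def rows_not_in(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return the rows of `a` that do not appear in `b` (hypothetical helper).

    Assumes a and b are 2-D with the same number of columns and dtype.
    Memory grows with N + M instead of N * M * D.
    """
    combined = torch.cat((a, b))
    # Identical rows receive the same integer id in `inverse`.
    _, inverse = torch.unique(combined, dim=0, return_inverse=True)
    ids_a = inverse[: a.shape[0]]
    ids_b = inverse[a.shape[0]:]
    # Keep rows of a whose id never occurs among b's ids; a's order is preserved.
    return a[~torch.isin(ids_a, ids_b)]


a = torch.as_tensor([[0, 0], [0, 1], [0, 2], [1, 3], [1, 4], [2, 1], [2, 4]])
b = torch.as_tensor([[0, 1], [0, 2], [1, 4], [2, 4]])
result = rows_not_in(a, b)
print(result)
```

This trades the big intermediate for a sort inside `unique`; whether it is actually faster than the broadcast version depends on the tensor sizes and the device, so it is worth timing both on real data.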
answered Mar 19 '23 at 07:03 by Guillem



In case someone is looking for a solution for 1-D tensors, this is an adaptation of @Guillem's solution:

import torch

a = torch.tensor(list(range(0, 10)))
b = torch.tensor(list(range(5, 15)))

a[~a.unsqueeze(1).eq(b).any(1)]

outputs:

tensor([0, 1, 2, 3, 4])
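For 1-D tensors, PyTorch 1.10 and newer also provide `torch.isin`, which does this membership test without building the comparison matrix by hand. A minimal sketch using the same `a` and `b` (assuming a PyTorch version that has `torch.isin`):

```python
import torch

a = torch.tensor(list(range(0, 10)))
b = torch.tensor(list(range(5, 15)))

# isin marks the elements of a that occur anywhere in b; ~ inverts the mask.
result = a[~torch.isin(a, b)]
print(result)  # tensor([0, 1, 2, 3, 4])
```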

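Here is another solution for when you want the symmetric difference (elements that appear in exactly one of the two tensors) rather than just the elements of the first tensor that are missing from the second. Be careful when using it, because order doesn't matter here: elements that exist only in b end up in the result too. It also assumes neither tensor contains duplicates on its own, since a value repeated within a would get a count of 2 and be dropped.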

combined = torch.cat((a, b))
uniques, counts = combined.unique(return_counts=True)
difference = uniques[counts == 1]

outputs:

tensor([ 0,  1,  2,  3,  4, 10, 11, 12, 13, 14])
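The same counting trick can be applied to the 2-D rows from the question by telling `unique` to compare whole rows with `dim=0`. A sketch, under the same assumption that neither tensor repeats a row on its own:

```python
import torch

a = torch.as_tensor([[0, 0], [0, 1], [0, 2], [1, 3], [1, 4], [2, 1], [2, 4]])
b = torch.as_tensor([[0, 1], [0, 2], [1, 4], [2, 4]])

combined = torch.cat((a, b))
# dim=0 treats each row as one element; counts[i] is how often uniques[i] occurs.
uniques, counts = combined.unique(dim=0, return_counts=True)
# Rows seen exactly once belong to only one of the two tensors.
difference = uniques[counts == 1]
print(difference)
```

Because every row of b also appears in a, the symmetric difference here coincides with the result the question asks for, [[0, 0], [1, 3], [2, 1]]. In general it would also include rows that exist only in b, and the rows come back sorted rather than in a's original order.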
answered Mar 19 '23 at 05:03 by Felipe Mello
