Delete duplicated rows in torch.tensor

I have a torch.tensor of shape (n, m) and I want to remove the duplicate rows (or at least find them). For example:

t1 = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])
t2 = remove_duplicates(t1)

t2 should now be equal to tensor([[1, 2, 3], [4, 5, 6]]), that is, the rows at indices 2 and 3 (the repeated copies) are removed. Do you know a way to perform this operation?

I was thinking of using torch.unique, but I cannot figure out how.
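For reference, here is a minimal sketch of the hypothetical remove_duplicates from the example, written as a plain Python loop (not torch.unique). It keeps the first occurrence of each row in its original order; it is slow for large tensors but handy for checking a vectorized version.

```python
import torch

def remove_duplicates(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper matching the example above: keeps the first
    # occurrence of each row and preserves the original row order.
    seen = set()
    keep = []
    for i, row in enumerate(t.tolist()):
        key = tuple(row)  # lists are unhashable, tuples can go in a set
        if key not in seen:
            seen.add(key)
            keep.append(i)
    # Index with a long tensor so an empty `keep` still works.
    return t[torch.tensor(keep, dtype=torch.long)]

t1 = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])
t2 = remove_duplicates(t1)
print(t2)
```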

asked Oct 21 '25 at 13:10 by aretor


1 Answer

You can simply use the dim parameter of torch.unique.

import torch

t1 = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
torch.unique(t1, dim=0)

This returns exactly the result you want:

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
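One caveat: as far as I know, torch.unique returns the unique rows in sorted (lexicographic) order, so the output above matches the input order only because t1 is already sorted. If the original order matters, a sketch like the following (assuming PyTorch 1.12 or later, where Tensor.scatter_reduce accepts reduce="amin") uses return_inverse to find the first occurrence of each unique row and then restores the input order.

```python
import torch

t1 = torch.tensor([[7, 8, 9], [1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]])

# uniq: unique rows (sorted); inverse[i]: index in uniq of original row i.
uniq, inverse = torch.unique(t1, dim=0, return_inverse=True)

# For each unique row, take the smallest original index mapping to it.
# The fill value t1.size(0) is larger than any real index, so "amin" ignores it.
idx = torch.arange(t1.size(0))
first = torch.full((uniq.size(0),), t1.size(0), dtype=torch.long)
first = first.scatter_reduce(0, inverse, idx, reduce="amin")

# Sorting the first-occurrence indices restores the original row order.
t2 = t1[first.sort().values]
print(t2)
```

Here torch.unique(t1, dim=0) alone would give the rows as [1, 2, 3], [4, 5, 6], [7, 8, 9], while t2 keeps them as [7, 8, 9], [1, 2, 3], [4, 5, 6].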

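The meaning of that parameter is explained in the torch.unique documentation: dim is the dimension along which uniqueness is checked, so dim=0 treats each row as a single element and compares whole rows.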
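If you only need to find the duplicates rather than remove them, torch.unique can also return, for each unique row, how many times it occurs (return_counts=True) and, for each original row, which unique row it maps to (return_inverse=True). A short sketch combining the two:

```python
import torch

t1 = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6], [7, 8, 9]])

# With both flags set, torch.unique returns (output, inverse_indices, counts).
uniq, inverse, counts = torch.unique(
    t1, dim=0, return_inverse=True, return_counts=True
)

# Unique rows that occur more than once in t1.
duplicated_rows = uniq[counts > 1]

# Mask over the original rows: True if that row appears more than once.
has_duplicate = counts[inverse] > 1

print(duplicated_rows)
print(has_duplicate)
```

For this t1, duplicated_rows contains [1, 2, 3] and [4, 5, 6], and has_duplicate is True for the first four rows and False for [7, 8, 9].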

answered Oct 23 '25 at 04:10 by Erosinho


