
Unique values in PyTorch tensor

I'm trying to find the distinct values in a PyTorch tensor. Is there an efficient analogue of TensorFlow's unique op?

arosa asked Jul 09 '17 05:07



3 Answers

torch.unique() is available starting with PyTorch 0.4.0.
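As a rough sketch of what that call looks like on a recent PyTorch release: the input values below are made up for illustration, and the `return_inverse` and `return_counts` keyword arguments are assumed to be available, which may not hold on the oldest releases that ship `torch.unique`.

```python
import torch

# Illustrative input; any 1-D tensor is handled the same way.
x = torch.tensor([1, 1, 2, 4, 4, 4, 7, 8, 8])

# Sorted distinct values (sorted=True is the default).
values = torch.unique(x)

# Optionally also get, for each input element, its index into the unique
# values, and how often each distinct value occurs.
values2, inverse, counts = torch.unique(x, return_inverse=True, return_counts=True)

print(values)   # tensor([1, 2, 4, 7, 8])
print(inverse)  # tensor([0, 0, 1, 2, 2, 2, 3, 4, 4])
print(counts)   # tensor([2, 1, 3, 1, 2])
```

Everything stays in PyTorch here, so there is no need to convert to NumPy and back.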

In PyTorch <= 0.3.1 you can go through NumPy instead:

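import torch
import numpy as np

# Random 3x3 tensor with values in [0, 10)
x = torch.rand((3, 3)) * 10
# Round to whole numbers so that duplicates can occur, then deduplicate in NumPy
np.unique(x.round().numpy())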
Alex Glinsky answered Nov 07 '22 17:11



You can convert the tensor to a NumPy array and use NumPy's built-in unique function:

import numpy as np
import torch

def unique(tensor1d):
    t, idx = np.unique(tensor1d.numpy(), return_inverse=True)  # sorted uniques, inverse indices
    return torch.from_numpy(t), torch.from_numpy(idx)

Example:

t, idx = unique(torch.LongTensor([1, 1, 2, 4, 4, 4, 7, 8, 8]))  
# t --> [1, 2, 4, 7, 8]
# idx --> [0, 0, 1, 2, 2, 2, 3, 4, 4]
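The inverse indices are what make this useful: indexing the unique values with them rebuilds the original tensor. Below is a small sketch of that check. The name `unique_np` is hypothetical, and the extra `.detach().cpu()` step is an addition of mine rather than part of the helper above; it is there because `.numpy()` only works on CPU tensors that do not require grad.

```python
import numpy as np
import torch

def unique_np(tensor1d):
    # Hypothetical variant of the helper: detach and move to the CPU first,
    # because .numpy() only works on CPU tensors that do not require grad.
    arr = tensor1d.detach().cpu().numpy()
    t, idx = np.unique(arr, return_inverse=True)
    return torch.from_numpy(t), torch.from_numpy(idx)

x = torch.LongTensor([1, 1, 2, 4, 4, 4, 7, 8, 8])
t, idx = unique_np(x)

# Indexing the unique values with the inverse indices rebuilds the input.
rebuilt = t[idx]
print(torch.equal(rebuilt, x))  # True
```

Note that the results come back as CPU tensors, so you would have to move them to the original device yourself if the input lived on a GPU.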
entrophy answered Nov 07 '22 16:11



  1. Find the positions where two tensors hold the same value with torch.eq().
  2. Use that mask to select the matching elements from both tensors and concatenate them.
  3. Finally, remove the duplicates with torch.unique:
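import torch as pt

a = pt.tensor([1, 2, 3, 2, 3, 4, 3, 4, 5, 6])
b = pt.tensor([7, 2, 3, 2, 7, 4, 9, 4, 9, 8])

# Boolean mask: True where a and b hold the same value at the same index
equal_data = pt.eq(a, b)
# Select the matching elements from both tensors and deduplicate them
pt.unique(pt.cat([a[equal_data], b[equal_data]]))  # tensor([2, 3, 4])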
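One thing to keep in mind is that torch.eq() compares element by element, so the snippet above only finds values that match at the same index. If you want values that occur anywhere in both tensors, one possible sketch is a broadcasted comparison. The tensors and variable names below are illustrative and not part of the answer above.

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([3, 2, 1])

# Position-wise match, as in the snippet above: only index 1 agrees.
same_position = torch.unique(a[torch.eq(a, b)])

# Position-independent match: compare every element of a with every element
# of b via broadcasting, then keep the elements of a found anywhere in b.
in_both = (a.unsqueeze(1) == b.unsqueeze(0)).any(dim=1)
shared_values = torch.unique(a[in_both])

print(same_position)  # tensor([2])
print(shared_values)  # tensor([1, 2, 3])
```

The broadcasted comparison builds a len(a) x len(b) boolean matrix, so it is only reasonable for tensors of moderate size.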
Sunil answered Nov 07 '22 17:11
