
How can I slice a PyTorch tensor with another tensor?

I have:

inp =  torch.randn(4, 1040, 161)

and I have another tensor called indices with values:

tensor([[124, 583, 158, 529],
        [172, 631, 206, 577]], device='cuda:0')

I want the equivalent of:

inp0 = inp[:,124:172,:]
inp1 = inp[:,583:631,:]
inp2 = inp[:,158:206,:]
inp3 = inp[:,529:577,:]

Except all added together, to have a .size of [4, 48, 161]. How can I accomplish this?

Currently, my solution is a for loop:

left_indices = torch.empty(inp.size(0), self.side_length, inp.size(2))
for batch_index in range(len(inp)):
    print(left_indices_start[batch_index].item())
    left_indices[batch_index] = inp[batch_index, left_indices_start[batch_index].item():left_indices_end[batch_index].item()]
Shamoon asked Oct 15 '22 04:10


2 Answers

Here you go (EDIT: you will probably need to copy the tensors to the CPU with tensor = tensor.cpu() before doing the following operations):

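import numpy as np
import torch

index = torch.tensor([[124, 583, 158, 529],
                      [172, 631, 206, 577]], device='cuda:0')
# copy the bounds to the CPU as plain Python integers
starts, ends = index.cpu().tolist()
# create a concatenated array of all the indices covered by the ranges
indexer = np.r_[tuple(np.s_[i:j] for i, j in zip(starts, ends))]
# slice using the index array; the four 48-row ranges end up side by side along dim 1
sliced_inp = inp[:, indexer, :]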

Here is how it works:

np.s_[i:j] creates a slice object (simply a range) of indices from start=i up to, but not including, end=j.

np.r_[i:j, k:m] creates an array of ALL indices in the slices (i,j) and (k,m). (You can pass more slices to np.r_ to concatenate them all at once; this example concatenates only two.)

Therefore, indexer is an array of ALL the indices, built by concatenating a list of slices (each slice is a range of indices).
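
To make the two helpers concrete, here is a small sketch with made-up ranges (2:5 and 10:13 are illustrative only, not taken from the question):

```python
import numpy as np

# np.s_ builds an ordinary Python slice object
s = np.s_[2:5]                      # slice(2, 5, None)

# np.r_ expands each slice and concatenates the results into one array
idx = np.r_[2:5, 10:13]             # array([ 2,  3,  4, 10, 11, 12])

# the same array built from a tuple of slice objects, as in the answer
idx_from_tuple = np.r_[tuple([np.s_[2:5], np.s_[10:13]])]

print(s, idx, idx_from_tuple)
```

Note that the end of each range is excluded, just as in regular Python slicing.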

UPDATE: If you need to remove interval overlaps and sort intervals:

indexer = np.unique(indexer)

If you want to remove interval overlaps without sorting, keeping the original order (and the first occurrence of each overlapping index):

uni = np.unique(indexer, return_index=True)[1]
indexer = [indexer[pos] for pos in sorted(uni)]
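
A quick way to see the difference between the two variants is to try them on a tiny, deliberately overlapping pair of ranges. The ranges 5:9 and 3:7 below are made up for illustration, and ordered_unique is just a hypothetical name for the second result:

```python
import numpy as np

# two overlapping ranges, out of order on purpose: 5..8 and 3..6
indexer = np.r_[5:9, 3:7]                      # [5 6 7 8 3 4 5 6]

# variant 1: duplicates removed, result sorted
sorted_unique = np.unique(indexer)             # [3 4 5 6 7 8]

# variant 2: duplicates removed, first occurrences kept in original order
_, first_pos = np.unique(indexer, return_index=True)
ordered_unique = indexer[np.sort(first_pos)]   # [5 6 7 8 3 4]

print(sorted_unique, ordered_unique)
```

Here the second variant indexes the array with the sorted first-occurrence positions instead of using a list comprehension; it should give the same values, only as a NumPy array rather than a Python list.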
Ehsan answered Oct 18 '22 15:10


import torch

inp = torch.randn(4, 1040, 161)
indices = torch.tensor([[124, 583, 158, 529],
                        [172, 631, 206, 577]])
# pair each start index with its end index
for i, j in zip(indices[0], indices[1]):
    print(inp[:, i:j, :])

You can implement it like this. The zip function pairs each start index with its end index, giving (start, end) tuples that you can use directly in a for loop.

Hope it helps you out....
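
If the goal is the [4, 48, 161] result from the question, one possible reading (an assumption, based on the asker's own loop) is that batch element b should get the range from column b of indices. A minimal sketch under that assumption, on the CPU, with looped and vectorized as hypothetical names:

```python
import torch

inp = torch.randn(4, 1040, 161)
indices = torch.tensor([[124, 583, 158, 529],
                        [172, 631, 206, 577]])
starts, ends = indices[0], indices[1]

# loop version: take one slice per batch element and stack them along dim 0
pairs = zip(starts.tolist(), ends.tolist())
looped = torch.stack([inp[b, s:e] for b, (s, e) in enumerate(pairs)])

# vectorized version: build a (batch, length) grid of row positions
length = int(ends[0] - starts[0])                   # assumes every range has the same length
rows = starts.unsqueeze(1) + torch.arange(length)   # shape (4, 48)
batch = torch.arange(inp.size(0)).unsqueeze(1)      # shape (4, 1), broadcasts against rows
vectorized = inp[batch, rows]                       # shape (4, 48, 161)

print(looped.shape, vectorized.shape)
```

The vectorized form relies on all ranges having equal length (48 here). If inp lives on the GPU, the two torch.arange calls and indices would need to be created on the same device.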

Neha_Jain answered Oct 18 '22 15:10