
torch logical indexing of tensor

Tags: indexing, torch

I am looking for an elegant way to select a subset of a torch tensor that satisfies some constraints. For example, say I have:

A = torch.rand(10,2)-1

and S is a 10x1 tensor:

sel = torch.ge(S,5) -- this is a ByteTensor

I would like to be able to do logical indexing, as follows:

A1 = A[sel]

But that doesn't work. There is the index function, which accepts a LongTensor, but I could not find a simple way to convert sel to a LongTensor, except the following:

sel = torch.nonzero(sel)

which returns a K x 2 tensor (K being the number of values of S >= 5). I then have to convert it to a one-dimensional tensor, which finally allows me to index A:

A:index(1,torch.squeeze(sel:select(2,1)))

This is very cumbersome; in Matlab, for example, all I'd have to do is

A(S>=5,:)

Can anyone suggest a better way?

Asked Mar 31 '16 20:03 by Amir Rosenfeld



1 Answer

One possible alternative is:

sel = S:ge(5):expandAs(A)   -- now you can use this mask with the [] operator
A1 = A[sel]:unfold(1, 2, 2) -- unfold into rows of 2 (A has 2 columns) to get back a 2D tensor

Example:

> A = torch.rand(3,2)-1
-0.0047 -0.7976
-0.2653 -0.4582
-0.9713 -0.9660
[torch.DoubleTensor of size 3x2]

> S = torch.Tensor{{6}, {1}, {5}}
 6
 1
 5
[torch.DoubleTensor of size 3x1]

> sel = S:ge(5):expandAs(A)
1  1
0  0
1  1
[torch.ByteTensor of size 3x2]

> A[sel]
-0.0047
-0.7976
-0.9713
-0.9660
[torch.DoubleTensor of size 4]

> A[sel]:unfold(1, 2, 2)
-0.0047 -0.7976
-0.9713 -0.9660
[torch.DoubleTensor of size 2x2]
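If you need this often, the same expand-and-select trick can be wrapped in a small helper that works for any number of columns. This is a minimal sketch, assuming Torch7 (Lua); `maskedRows` is a hypothetical name, not a Torch function. It applies `maskedSelect` with the expanded mask and reshapes with `view(k, m)` instead of the hard-coded `unfold(1, 2, 2)`:

```lua
require 'torch'

-- Hypothetical helper (not part of Torch7): return the rows of the 2D tensor A
-- whose entry in the n x 1 ByteTensor mask is 1, as a K x m tensor.
local function maskedRows(A, mask)
  local m = A:size(2)        -- number of columns in A
  local k = mask:sum()       -- number of selected rows
  if k == 0 then
    return A.new()           -- no rows selected: empty tensor of the same type
  end
  -- Expand the n x 1 mask to n x m, pick out the selected elements
  -- (row by row, as a 1D tensor), then reshape them back into k rows.
  return A:maskedSelect(mask:expandAs(A)):view(k, m)
end
```

With the tensors from the example above, `maskedRows(A, S:ge(5))` gives the same 2x2 result as `A[sel]:unfold(1, 2, 2)`, but the column count is taken from `A` rather than written into the call.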
Answered Oct 14 '22 06:10 by deltheil