 

Partial slices in pytorch / numpy with arbitrary and variable number of dimensions

Given a 2-dimensional tensor in numpy (or pytorch), I can partially slice along all dimensions at once as follows:

>>> import numpy as np
>>> a = np.arange(3*4).reshape(3,4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
>>> a[1:,1:]
array([[ 5,  6,  7],
       [ 9, 10, 11]])

How can I achieve the same slicing pattern for a tensor with any number of dimensions, when that number is not known at implementation time? (I.e., I want a[1:] if a has one dimension, a[1:,1:] if it has two, a[1:,1:,1:] if it has three, and so on.)

It would be nice if I could do it in a single line of code like the following, but this is invalid:

a[(1:,) * len(a.shape)]  # SyntaxError: invalid syntax

I am specifically interested in a solution that works for pytorch tensors (just substitute torch for numpy above and the example is the same), but ideally the solution would work for both numpy and pytorch.

asked by teichert, Oct 27 '25 10:10



1 Answer

Answer: Making a tuple of slice objects does the trick:

a[(slice(1,None),) * len(a.shape)]
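A minimal sketch of the same one-liner wrapped in a helper, so it can be checked on arrays of different rank. The name slice_all_dims and its start parameter are hypothetical, introduced only for illustration. The sketch uses numpy arrays; per the answer, the same indexing expression applies to pytorch tensors.

```python
import numpy as np


def slice_all_dims(a, start=1):
    """Hypothetical helper: return a[start:, start:, ...], one slice per dimension."""
    # Build one slice(start, None) object per axis; len(a.shape) is the rank.
    index = (slice(start, None),) * len(a.shape)
    return a[index]


for shape in [(4,), (3, 4), (2, 3, 4)]:
    a = np.arange(np.prod(shape)).reshape(shape)
    print(shape, "->", slice_all_dims(a).shape)
# (4,) -> (3,)
# (3, 4) -> (2, 3)
# (2, 3, 4) -> (1, 2, 3)
```

Because the tuple length comes from the array itself, the same call covers 1-D, 2-D, 3-D and higher without any branching on rank.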

Explanation: slice is a built-in Python class (not tied to numpy or pytorch) that provides an alternative to the subscript notation for describing slices. The answer to a different question suggests using it to store slice information in Python variables. The Python glossary points out that

The bracket (subscript) notation uses slice objects internally.

Since subscript notation such as a[1:,1:] is converted into a tuple of slice objects before it reaches __getitem__, the __getitem__ methods of numpy ndarrays and pytorch tensors, which support multidimensional indexing with slices, must also accept a tuple of slice objects passed directly. So we can build a tuple of slice objects of the right length ourselves.
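A small numpy check of the equivalence argued above: indexing with subscript syntax and indexing with an explicitly built tuple of slice objects select the same elements. The variable names are illustrative only.

```python
import numpy as np

a = np.arange(3 * 4).reshape(3, 4)

# Subscript notation: Python builds the tuple of slices for us.
by_subscript = a[1:, 1:]

# Explicit tuple of slice objects: the same index, built as ordinary data.
index = (slice(1, None), slice(1, None))
by_tuple = a[index]

print(np.array_equal(by_subscript, by_tuple))  # True
```

Since index is just a Python tuple, it can be built with (slice(1, None),) * len(a.shape) for any rank, which is exactly what the one-liner in the answer does.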

By the way, you can see how Python passes slice objects to __getitem__ by creating a dummy class as follows and then slicing it:

class A(object):
    def __getitem__(self, ix):
        return ix

print(A()[5])  # 5
print(A()[1:])  # slice(1, None, None)
print(A()[1:,1:])  # (slice(1, None, None), slice(1, None, None))
print(A()[1:,slice(1,None)])  #  (slice(1, None, None), slice(1, None, None))
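NumPy already provides an object that behaves like the dummy class A: np.s_, whose __getitem__ returns the index expression it receives. A short sketch using it to build the same variable-length tuple; this is an alternative numpy-only spelling, not what the answer above uses.

```python
import numpy as np

# np.s_ returns whatever index expression it is subscripted with, like class A.
print(np.s_[1:])      # slice(1, None, None)
print(np.s_[1:, 1:])  # (slice(1, None, None), slice(1, None, None))

# Repeat the single slice once per dimension, as in the answer's one-liner.
a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
index = (np.s_[1:],) * len(a.shape)
print(a[index].shape)  # (1, 2, 3)
```

The resulting index is a plain tuple of built-in slice objects, so it is interchangeable with (slice(1, None),) * len(a.shape).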


answered by teichert, Oct 29 '25 00:10



