numpy index on a variable axis

I want to index a multidimensional array along an axis chosen at run time, like this:

import numpy

a = numpy.arange(12).reshape(3, 2, 2)

def fun(axis, state):
    # if axis=0
    return a[state, :, :]
    # if axis=1 it should return a[:, state, :]

Sample outputs:

fun(0, 1)
array([[4, 5],
       [6, 7]])

fun(1, 1)
array([[ 2,  3],
       [ 6,  7],
       [10, 11]])

In short I want to accept the axis as an argument.

I can't think of a way to do this. Any possible solutions?

Ankur Ankan asked Sep 07 '25 18:09



1 Answer

You can take a view of the array with a specified axis moved to the front using numpy.rollaxis:

import numpy

def fun(a, axis, state):
    return numpy.rollaxis(a, axis)[state]

Demo:

>>> import numpy
>>> a = numpy.arange(12).reshape([3, 2, 2])
>>> def fun(a, axis, state):
...     return numpy.rollaxis(a, axis)[state]
...
>>> fun(a, 0, 1)
array([[4, 5],
       [6, 7]])
>>> fun(a, 1, 1)
array([[ 2,  3],
       [ 6,  7],
       [10, 11]])
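
For comparison, here is a minimal sketch of two other ways to get the same result without numpy.rollaxis. The names fun_slices and fun_take are made up for illustration and are not part of the answer above. One builds the index tuple by hand, the other uses numpy.take, which returns a copy rather than a view.

```python
import numpy

a = numpy.arange(12).reshape(3, 2, 2)

def fun_slices(a, axis, state):
    # slice(None) is the programmatic spelling of ':'
    index = [slice(None)] * a.ndim
    # Put the requested position on the chosen axis
    index[axis] = state
    # Index with a tuple so NumPy treats it as one multi-axis index
    return a[tuple(index)]

def fun_take(a, axis, state):
    # With a scalar index, numpy.take drops the selected axis
    return numpy.take(a, state, axis=axis)

print(fun_slices(a, 0, 1))
print(fun_take(a, 1, 1))
```

Both functions should agree with the rollaxis version for any valid axis and state.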

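numpy.rollaxis also supports moving an axis to a position other than the front through its optional third argument, start, though the way it interprets that argument is unintuitive: the axis is rolled until it lies before position start, rather than ending up at that position. The NumPy documentation recommends numpy.moveaxis for that use.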
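
A small sketch of that difference, using a throwaway array b (not from the question) whose three axes have distinct lengths so the resulting shapes show where each axis went:

```python
import numpy

b = numpy.zeros((3, 4, 5))

# rollaxis(b, axis, start) rolls `axis` until it lies *before* position `start`
rolled = numpy.rollaxis(b, 0, 2)

# moveaxis(b, source, destination) puts `source` *at* position `destination`
moved = numpy.moveaxis(b, 0, 2)

print(rolled.shape)  # (4, 3, 5)
print(moved.shape)   # (4, 5, 3)
```

When the axis is only ever moved to the front, as in fun above, the two functions give the same result.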

user2357112 supports Monica answered Sep 11 '25 07:09
