I want to index over a multidimensional array like this:
import numpy
a = numpy.arange(12).reshape(3, 2, 2)
def fun(axis, state):
    # if axis=0
    return a[state, :, :]
    # if axis=1 it should return a[:, state, :]
Sample outputs:
fun(0, 1)
array([[4, 5],
       [6, 7]])
fun(1, 1)
array([[ 2,  3],
       [ 6,  7],
       [10, 11]])
In short, I want the function to accept the axis as an argument.
I can't think of a way to do this. Any possible solutions?
You can take a view of the array with a specified axis moved to the front using numpy.rollaxis:
import numpy
def fun(a, axis, state):
    # Move `axis` to the front, then index along it.
    return numpy.rollaxis(a, axis)[state]
Demo:
>>> import numpy
>>> a = numpy.arange(12).reshape([3, 2, 2])
>>> def fun(a, axis, state):
...     return numpy.rollaxis(a, axis)[state]
...
>>> fun(a, 0, 1)
array([[4, 5],
       [6, 7]])
>>> fun(a, 1, 1)
array([[ 2,  3],
       [ 6,  7],
       [10, 11]])
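As a quick sanity check of the approach above, the sketch below compares fun against the hand-written slices for each axis and confirms that the result is a view rather than a copy. The use of numpy.shares_memory is an addition for illustration only and is not part of the original answer.

```python
import numpy

def fun(a, axis, state):
    # Move `axis` to the front, then index along it.
    return numpy.rollaxis(a, axis)[state]

a = numpy.arange(12).reshape(3, 2, 2)

# Same result as writing the slice out by hand for each axis.
assert numpy.array_equal(fun(a, 0, 1), a[1, :, :])
assert numpy.array_equal(fun(a, 1, 1), a[:, 1, :])
assert numpy.array_equal(fun(a, 2, 1), a[:, :, 1])

# The result shares memory with `a`, i.e. it is a view, not a copy.
assert numpy.shares_memory(fun(a, 1, 1), a)
```

Because the result is a view, writing into it also changes a. Call .copy() on the result if that is not what you want.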
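numpy.rollaxis also supports moving the axis to other positions through its optional start argument, though the way it interprets that argument is kind of weird: the axis ends up just before the position that start refers to in the original array, so when you move an axis towards the back it lands at index start - 1 rather than start.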
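To make the argument handling concrete, here is a small sketch on a 4-dimensional array whose axis lengths are all different, so the shapes show where each axis ends up. The comparison with numpy.moveaxis is an addition that the original answer does not mention. It assumes a NumPy version that provides moveaxis (1.11 or later).

```python
import numpy

a = numpy.zeros((3, 4, 5, 6))

# Moving an axis towards the front: it ends up at index `start`.
assert numpy.rollaxis(a, 3, 1).shape == (3, 6, 4, 5)

# Moving an axis towards the back: it lands *before* position `start`,
# so it ends up at index start - 1.
assert numpy.rollaxis(a, 1, 3).shape == (3, 5, 4, 6)

# numpy.moveaxis takes the final index of the axis directly.
assert numpy.moveaxis(a, 1, 3).shape == (3, 5, 6, 4)
assert numpy.moveaxis(a, 1, 2).shape == (3, 5, 4, 6)
```

For the function in this answer only the default start=0 matters, so numpy.moveaxis(a, axis, 0)[state] would be an equivalent way to write it.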