
Numpy Apply Along Axis and Get Row Index

Tags: python, numpy

I have a 2D array (it is actually very large and a view of another array):

import numpy as np

x = np.array([[0, 1, 2],
              [1, 2, 3],
              [2, 3, 4],
              [3, 4, 5]])

And I have a function that processes each row of the array:

def some_func(a):
    """
    Some function that does something funky with a row of numbers
    """
    return [a[2], a[0]]  # This is not so funky

np.apply_along_axis(some_func, 1, x)

What I'm looking for is a way to call np.apply_along_axis so that the function also receives the index of the row being processed, letting me process each row with a function like this:

def some_func(a, idx):
    """
    I plan to use the index for some logic on which columns to
    return. This is only an example
    """
    return [idx, a[2], a[0]]  # This is not so funky
slaw asked Mar 01 '17 20:03


1 Answer

For a 2d array with axis=1, apply_along_axis is equivalent to iterating over the rows of the array:

In [149]: np.apply_along_axis(some_func, 1, x)
Out[149]: 
array([[2, 0],
       [3, 1],
       [4, 2],
       [5, 3]])
In [151]: np.array([some_func(i) for i in x])
Out[151]: 
array([[2, 0],
       [3, 1],
       [4, 2],
       [5, 3]])

For axis=0 we could iterate over x.T instead. apply_along_axis is more useful when the array is 3d and we want to iterate over all dimensions except one; there it saves some tedium. But it is not a speed tool: it still calls the Python function once per 1d slice.
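To make the 3d case concrete, here is a minimal sketch. The array `y` and its shape are made up for illustration; `some_func` is the one-argument version from the question. It compares apply_along_axis with an explicit loop over the other two axes, and shows the x.T idea for axis=0.

```python
import numpy as np

def some_func(a):
    # Toy function from the question: pick two elements of a 1-D slice.
    return [a[2], a[0]]

# Hypothetical 3-D input, chosen only for illustration.
y = np.arange(24).reshape(2, 4, 3)

# apply_along_axis passes each 1-D slice along the last axis to some_func.
res = np.apply_along_axis(some_func, 2, y)

# The same result with an explicit loop over the first two axes.
manual = np.empty((2, 4, 2), dtype=y.dtype)
for i, j in np.ndindex(y.shape[:2]):
    manual[i, j] = some_func(y[i, j])

# axis=0 on a 2-D array: iterate over the transpose to visit the columns,
# then transpose back so the layout matches apply_along_axis.
x = y[0]
cols = np.array([some_func(c) for c in x.T]).T
```

Both versions call some_func once per slice, so the loop is not expected to be meaningfully slower; apply_along_axis mainly spares you the index bookkeeping.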

With your revised function, we can use the standard enumerate to get both the row and its index:

In [153]: np.array([some_func(v,i) for i,v in enumerate(x)])
Out[153]: 
array([[0, 2, 0],
       [1, 3, 1],
       [2, 4, 2],
       [3, 5, 3]])

or with a simple range iteration:

In [157]: np.array([some_func(x[i],i) for i in range(x.shape[0])])
Out[157]: 
array([[0, 2, 0],
       [1, 3, 1],
       [2, 4, 2],
       [3, 5, 3]])

There are various tools for generating indices in higher dimensions, such as np.ndenumerate and np.ndindex.
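As a rough sketch of how those two helpers behave, the example below uses a made-up 3d array `y` and passes the second-to-last index to the two-argument `some_func` from the question. Which index to pass is an assumption made for this example.

```python
import numpy as np

def some_func(a, idx):
    # Two-argument toy function from the question.
    return [idx, a[2], a[0]]

# Hypothetical 3-D input, chosen only for illustration.
y = np.arange(24).reshape(2, 4, 3)

# np.ndindex yields one index tuple per position in the given shape.
# Here that is every (i, j) pair, leaving the last axis as the 1-D slice.
out = np.empty(y.shape[:-1] + (3,), dtype=y.dtype)
for idx in np.ndindex(y.shape[:-1]):
    out[idx] = some_func(y[idx], idx[-1])

# np.ndenumerate yields (index tuple, value) pairs for every element.
pairs = [(idx, int(v)) for idx, v in np.ndenumerate(y[0, :2, :2])]
```

ndindex fits when the function needs a whole slice plus its position. ndenumerate fits when it works element by element.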

The fast solution is to work on all rows at once:

In [158]: np.column_stack((np.arange(4), x[:,2], x[:,0]))
Out[158]: 
array([[0, 2, 0],
       [1, 3, 1],
       [2, 4, 2],
       [3, 5, 3]])
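The same idea can be written without hard-coding the row count. If the index is meant to drive which column is picked, integer-array indexing can often do that without a Python loop. The rule `idx % ncols` below is purely hypothetical, standing in for whatever logic the real function would apply.

```python
import numpy as np

x = np.array([[0, 1, 2],
              [1, 2, 3],
              [2, 3, 4],
              [3, 4, 5]])

# Row indices derived from the array itself rather than a literal 4.
idx = np.arange(x.shape[0])

# Vectorized equivalent of [idx, a[2], a[0]] for every row.
fast = np.column_stack((idx, x[:, 2], x[:, 0]))

# Hypothetical index-dependent rule: row i picks column i % ncols.
cols = idx % x.shape[1]
picked = x[idx, cols]  # one element per row, chosen by that row's index
```

Whether this applies depends on whether the per-row logic can be expressed as array operations; if it cannot, the enumerate loop above remains the straightforward fallback.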
hpaulj answered Oct 15 '22 11:10