numpy: summing along all but last axis

Tags: python, numpy

If I have an ndarray of arbitrary shape and want to compute the sum along all but the last axis, I can, for instance, do:

all_but_last = tuple(range(arr.ndim - 1))
total = arr.sum(axis=all_but_last)
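To see concretely what the axis tuple does, here is a small self-contained check. The 2x3x4 array is made up for illustration. For a 3-D array, tuple(range(arr.ndim - 1)) is (0, 1), and each entry of the result equals the sum of the matching last-axis slice arr[..., k].

```python
import numpy as np

# Illustrative 3-D array of shape (2, 3, 4); any shape works the same way
arr = np.arange(24, dtype=float).reshape(2, 3, 4)

all_but_last = tuple(range(arr.ndim - 1))   # (0, 1) for a 3-D array
total = arr.sum(axis=all_but_last)          # result has shape (4,)

# Entry k is the sum of every element whose last index is k
expected = np.array([arr[..., k].sum() for k in range(arr.shape[-1])])

print(all_but_last, total.shape)  # (0, 1) (4,)
print(total)                      # [60. 66. 72. 78.]
print(np.allclose(total, expected))  # True
```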

Now, tuple(range(arr.ndim - 1)) doesn't feel very intuitive to me. Is there a more elegant, NumPy-idiomatic way to do this?

Moreover, if I want to do this for multiple arrays of varying shape, I have to build a separate axis tuple for each of them. Is there a more canonical way to say "regardless of what the dimensions are, just give me all but one axis"?
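One way to avoid building the tuple by hand for every array is to wrap it in a small helper. The name sum_all_but_last below is hypothetical, not a NumPy function; the sketch only relies on ndarray.sum accepting a tuple of axes. For a 1-D input the axis tuple is empty, and NumPy then performs no reduction, so the array comes back unchanged.

```python
import numpy as np


def sum_all_but_last(a):
    """Sum over every axis except the last one (hypothetical helper)."""
    a = np.asarray(a)
    # Axes 0 .. ndim-2; empty for a 1-D array, which means "reduce nothing"
    return a.sum(axis=tuple(range(a.ndim - 1)))


# Illustrative arrays of different shapes sharing the same last-axis length
arrays = [
    np.ones((5, 3)),
    np.ones((2, 4, 3)),
    np.ones((3,)),
]
results = [sum_all_but_last(a) for a in arrays]
for a, r in zip(arrays, results):
    print(a.shape, "->", r)
```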

asked Aug 06 '18 at 15:08 by DarthPumpkin



2 Answers

You can use np.apply_over_axes to sum over multiple axes.

np.apply_over_axes(np.sum, arr, [0, 2])  # sum over axes 0 and 2

np.apply_over_axes(np.sum, arr, range(arr.ndim - 1))  # sum over all but the last axis
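One detail worth checking with np.apply_over_axes: it keeps every reduced axis as a length-1 dimension, so the result has the same ndim as the input. That matches arr.sum(axis=..., keepdims=True) rather than the 1-D result from the question. The sketch below (random data made up for illustration) compares the two and flattens the result with reshape(-1).

```python
import numpy as np

rng = np.random.default_rng(0)
arr = rng.random((2, 3, 4))  # illustrative data

# apply_over_axes reduces one axis at a time and keeps it with length 1
kept = np.apply_over_axes(np.sum, arr, range(arr.ndim - 1))
print(kept.shape)  # (1, 1, 4)

# Same values as a keepdims reduction over the tuple of axes
same = arr.sum(axis=tuple(range(arr.ndim - 1)), keepdims=True)
print(np.allclose(kept, same))  # True

# Drop the length-1 axes to get the 1-D result from the question
flat = kept.reshape(-1)
print(flat.shape)  # (4,)
```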
answered Oct 05 '22 at 22:10 by Easton Bornemeier


You could reshape the array so that all axes except the last are flattened (e.g. shape (k, l, m, n) becomes (k*l*m, n)), and then sum over the first axis.

For example, here's your calculation:

In [170]: arr.shape
Out[170]: (2, 3, 4)

In [171]: arr.sum(axis=tuple(range(arr.ndim - 1)))
Out[171]: array([2.85994792, 2.8922732 , 2.29051163, 2.77275709])

Here's the alternative:

In [172]: arr.reshape(-1, arr.shape[-1]).sum(axis=0)
Out[172]: array([2.85994792, 2.8922732 , 2.29051163, 2.77275709])
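A minimal sketch that checks the reshape approach against the tuple-of-axes version on a few arbitrarily chosen shapes, including a 1-D array. It also illustrates a cost to keep in mind: reshape returns a view when the memory layout allows it and otherwise copies, as with the transposed (non-contiguous) array at the end.

```python
import numpy as np

rng = np.random.default_rng(42)

for shape in [(2, 3, 4), (5, 4), (4,), (2, 2, 2, 3)]:
    arr = rng.random(shape)
    via_axes = arr.sum(axis=tuple(range(arr.ndim - 1)))
    # Collapse all leading axes into one, then sum over it
    via_reshape = arr.reshape(-1, arr.shape[-1]).sum(axis=0)
    assert via_axes.shape == via_reshape.shape == (shape[-1],)
    assert np.allclose(via_axes, via_reshape)

# On a non-contiguous array (here a transpose), reshape has to copy the data
t = rng.random((3, 4, 5)).transpose(2, 1, 0)  # shape (5, 4, 3)
print(np.shares_memory(t, t.reshape(-1, t.shape[-1])))  # False
```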
answered Oct 05 '22 at 23:10 by Warren Weckesser