
What is the role of keepdims in NumPy (Python)?




While using np.sum, I came across a parameter called keepdims. Even after reading the docs, I still don't understand what keepdims does.

keepdims: bool, optional

If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original arr.

I would appreciate it if someone could explain this with a simple example.

watermelon asked Dec 02 '16 07:12


1 Answer

Consider a small 2d array:

In [180]: A=np.arange(12).reshape(3,4)
In [181]: A
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

Sum along each row (axis=1); the result is a (3,) array:

In [182]: A.sum(axis=1)
Out[182]: array([ 6, 22, 38])

But subtracting this sum from A (or dividing A by it) requires reshaping, because a (3,) array does not broadcast against a (3,4) array:

In [183]: A-A.sum(axis=1)
ValueError: operands could not be broadcast together with shapes (3,4) (3,) 
In [184]: A-A.sum(axis=1)[:,None]   # turn sum into (3,1)
array([[ -6,  -5,  -4,  -3],
       [-18, -17, -16, -15],
       [-30, -29, -28, -27]])

If I use keepdims, "the result will broadcast correctly against" A.

In [185]: A.sum(axis=1, keepdims=True)   # (3,1) array
array([[ 6],
       [22],
       [38]])
In [186]: A-A.sum(axis=1, keepdims=True)
array([[ -6,  -5,  -4,  -3],
       [-18, -17, -16, -15],
       [-30, -29, -28, -27]])
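
As a rough check of the shapes involved, the following sketch compares the reduced result with and without keepdims. The variable names (row_sums, row_sums_kept, same, diff) are chosen here for illustration and are not part of the session above.

```python
import numpy as np

A = np.arange(12).reshape(3, 4)

row_sums = A.sum(axis=1)                      # axis 1 is dropped
row_sums_kept = A.sum(axis=1, keepdims=True)  # axis 1 is kept with size 1

print(row_sums.shape)       # (3,)
print(row_sums_kept.shape)  # (3, 1)

# keepdims=True yields the same array as adding the axis back by hand
same = np.array_equal(row_sums_kept, row_sums[:, None])
print(same)  # True

# (3, 4) against (3, 1) broadcasts; the result keeps A's shape
diff = A - row_sums_kept
print(diff.shape)  # (3, 4)
```

In other words, keepdims=True saves the manual [:,None] step and leaves the reduced axis in place as a length-1 dimension.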

If I sum along the other axis, I don't need keepdims. Broadcasting adds the missing leading dimension automatically, as if I had written A.sum(axis=0)[None,:]. But there's no harm in using keepdims.

In [190]: A.sum(axis=0)
Out[190]: array([12, 15, 18, 21])    # (4,)
In [191]: A-A.sum(axis=0)
array([[-12, -14, -16, -18],
       [ -8, -10, -12, -14],
       [ -4,  -6,  -8, -10]])
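
A small sketch of why the axis=0 case works either way: broadcasting aligns shapes from the right, so a (4,) array is treated like a (1,4) array. The names col_sums, col_sums_kept, plain and kept are illustrative only.

```python
import numpy as np

A = np.arange(12).reshape(3, 4)

col_sums = A.sum(axis=0)                      # shape (4,)
col_sums_kept = A.sum(axis=0, keepdims=True)  # shape (1, 4)

print(col_sums.shape, col_sums_kept.shape)

# (3, 4) - (4,) and (3, 4) - (1, 4) broadcast to the same result
plain = A - col_sums
kept = A - col_sums_kept
print(np.array_equal(plain, kept))  # True
```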

These operations may make more sense with np.mean, for example when centering or normalizing the array over rows or columns. In either case, keepdims simplifies further math between the original array and its sum or mean.
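
To illustrate the np.mean case, here is a minimal sketch that centers an array along either axis. The helper center is hypothetical, written for this example; it is not a NumPy function.

```python
import numpy as np

def center(x, axis):
    """Hypothetical helper: subtract the mean taken along `axis`.

    keepdims=True keeps the reduced axis as size 1, so the same
    expression broadcasts correctly for any choice of axis.
    """
    return x - x.mean(axis=axis, keepdims=True)

A = np.arange(12).reshape(3, 4)

row_centered = center(A, axis=1)  # (3, 4) - (3, 1)
col_centered = center(A, axis=0)  # (3, 4) - (1, 4)

print(row_centered.mean(axis=1))  # each row now averages to 0
print(col_centered.mean(axis=0))  # each column now averages to 0
```

Without keepdims, the helper would need a different reshape for each axis; with it, one line covers both cases.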

hpaulj answered Sep 21 '22 21:09
