
Why does numpy.dot behave in this way?

I'm trying to understand why numpy's dot function behaves as it does:

M = np.ones((9, 9))
V1 = np.ones((9,))
V2 = np.ones((9, 5))
V3 = np.ones((2, 9, 5))
V4 = np.ones((3, 2, 9, 5))

Now np.dot(M, V1) and np.dot(M, V2) behave as expected. But for V3 and V4 the result surprises me:

>>> np.dot(M, V3).shape
(9, 2, 5)
>>> np.dot(M, V4).shape
(9, 3, 2, 5)

I expected (2, 9, 5) and (3, 2, 9, 5) respectively. On the other hand, np.matmul does what I expect: the matrix multiply is broadcast over the first N - 2 dimensions of the second argument and the result has the same shape:

>>> np.matmul(M, V3).shape
(2, 9, 5)
>>> np.matmul(M, V4).shape
(3, 2, 9, 5)

So my question is this: what is the rationale for np.dot behaving as it does? Does it serve some particular purpose, or is it the result of applying some general rule?

fadriaensen asked Nov 29 '15 11:11



1 Answer

From the docs for np.dot:

For 2-D arrays it is equivalent to matrix multiplication, and for 1-D arrays to inner product of vectors (without complex conjugation). For N dimensions it is a sum product over the last axis of a and the second-to-last of b:

dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])

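For np.dot(M, V3),

(9, [9]), (2, [9], 5) --> (9, 2, 5)

For np.dot(M, V4),

(9, [9]), (3, 2, [9], 5) --> (9, 3, 2, 5)

The bracketed dimensions are the ones summed over, and are therefore not present in the result. The remaining dimensions of the first argument come first in the output, followed by the remaining dimensions of the second argument, which is why the 9 from M ends up in front.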
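As a rough check of the rule quoted above, the same contraction can be spelled out with np.einsum and np.tensordot. This is only a sketch: the random inputs, the seed, and the variable names (dot3, ein3, td3 and so on) are illustrative choices, not part of the question.

```python
import numpy as np

# Random data instead of np.ones, so that a wrong axis pairing would show up.
rng = np.random.default_rng(0)
M = rng.standard_normal((9, 9))
V3 = rng.standard_normal((2, 9, 5))
V4 = rng.standard_normal((3, 2, 9, 5))

# np.dot sums over the last axis of M and the second-to-last axis of V3 / V4.
dot3 = np.dot(M, V3)
dot4 = np.dot(M, V4)

# The same sum with explicit index labels; j is the contracted index.
ein3 = np.einsum('ij,kjm->ikm', M, V3)
ein4 = np.einsum('ij,kljm->iklm', M, V4)

# tensordot with the contracted axes named explicitly: axis 1 of M against
# the second-to-last axis of the other operand.
td3 = np.tensordot(M, V3, axes=([1], [1]))
td4 = np.tensordot(M, V4, axes=([1], [2]))

print(dot3.shape, dot4.shape)
print(np.allclose(dot3, ein3), np.allclose(dot4, ein4))
print(np.allclose(dot3, td3), np.allclose(dot4, td4))
```

In all three spellings the leftover axis of M comes first in the output, followed by the leftover axes of the second operand.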


In contrast, np.matmul treats N-dimensional arrays as 'stacks' of 2D matrices:

The behavior depends on the arguments in the following way.

  • If both arguments are 2-D they are multiplied like conventional matrices.
  • If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.

The same reductions are performed in both cases, but the order of the axes is different. np.matmul essentially does the equivalent of:

out1 = np.empty((2, 9, 5))  # same shape as np.matmul(M, V3)
for ii in range(V3.shape[0]):
    out1[ii, :, :] = np.dot(M[:, :], V3[ii, :, :])

and

out2 = np.empty((3, 2, 9, 5))  # same shape as np.matmul(M, V4)
for ii in range(V4.shape[0]):
    for jj in range(V4.shape[1]):
        out2[ii, jj, :, :] = np.dot(M[:, :], V4[ii, jj, :, :])
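
To make the "same reductions, different axis order" point concrete, here is a small sketch that moves the leading axis of the np.dot result into the second-to-last position and compares it with np.matmul. The random inputs and the names mm3, moved3, mm4 and moved4 are assumptions made for illustration.

```python
import numpy as np

# Random data so that equality is a meaningful check.
rng = np.random.default_rng(0)
M = rng.standard_normal((9, 9))
V3 = rng.standard_normal((2, 9, 5))
V4 = rng.standard_normal((3, 2, 9, 5))

# np.matmul broadcasts M over the leading (stack) axes of V3 / V4.
mm3 = np.matmul(M, V3)
mm4 = np.matmul(M, V4)

# np.dot puts M's remaining axis first; moving that axis to position -2
# restores the stack-of-matrices layout.
moved3 = np.moveaxis(np.dot(M, V3), 0, -2)
moved4 = np.moveaxis(np.dot(M, V4), 0, -2)

print(mm3.shape, moved3.shape)
print(mm4.shape, moved4.shape)
print(np.allclose(mm3, moved3), np.allclose(mm4, moved4))
```

For these shapes the two functions compute the same numbers and differ only in where the axis of length 9 contributed by M lands in the output.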

ali_m answered Nov 02 '22 22:11