
Right multiplication of a dense array with a sparse matrix

If I have a numpy.ndarray A and a scipy.sparse.csc_matrix B, how do I take A dot B? I can do B dot A by saying B.dot(A), but the other way I can only think of this:

B.T.dot(A.T).T

Is there a more direct method to do this?

asked Oct 20 '22 13:10 by shaoyl85


1 Answer

Your question initially confused me, since for my version of scipy, A.dot(B) and np.dot(A, B) both work fine; the .dot method of the sparse matrix simply overrides np.dot. However, it seems that this feature was added in this pull request, and is not present in versions of scipy older than 0.14.0. I'm guessing that you have one of these older versions.

Here's some test data:

import numpy as np
from scipy import sparse

A = np.random.randn(1000, 2000)
B = sparse.rand(2000, 3000, format='csr')

For versions of scipy >= 0.14.0, you can simply use:

C = A.dot(B)
C = np.dot(A, B)

For versions < 0.14.0, both of these will raise a ValueError:

In [6]: C = A.dot(B)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-6-7fbaa337fd94> in <module>()
----> 1 C = A.dot(B)

ValueError: Cannot find a common data type.

Instead, you could use one of:

# your original solution
%timeit B.T.dot(A.T).T
# 10 loops, best of 3: 93.1 ms per loop

# post-multiply A by B
%timeit B.__rmul__(A)
# 10 loops, best of 3: 91.9 ms per loop

As you can see, there's basically no performance difference (93.1 ms vs. 91.9 ms per loop), although I personally think the second version is more readable.
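If you want to convince yourself that the two workarounds really compute A dot B, here is a minimal sketch that compares both against a plain dense product. The small shapes, the fixed seed, and the names dense_B and reference are my own choices for illustration, not part of the original test data.

```python
import numpy as np
from scipy import sparse

rng = np.random.RandomState(0)

# Small dense array A and a mostly-zero matrix stored in CSC format.
A = rng.randn(4, 5)
dense_B = rng.randn(5, 3)
dense_B[np.abs(dense_B) < 1.0] = 0.0   # zero out the smaller entries
B = sparse.csc_matrix(dense_B)

# Reference result computed entirely with dense arrays.
reference = A.dot(dense_B)

# Workaround 1: use (B^T A^T)^T = A B.
C_transpose = B.T.dot(A.T).T

# Workaround 2: post-multiply A by B via the reflected operator.
C_rmul = B.__rmul__(A)

print(np.allclose(C_transpose, reference))
print(np.allclose(C_rmul, reference))
```

Both checks rely only on the transpose identity, so they should hold whichever scipy version you are on.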


Update:

As @shaoyl85 just pointed out, one can just use the * operator rather than calling the __rmul__() method directly:

# equivalent to B.__rmul__(A)
C = A * B

It seems that sparse matrices take priority over ndarrays when determining the behavior of the * operator, so A * B is a matrix product here. This is a potential gotcha for those of us who are more used to the behavior of ndarrays (where * means elementwise multiplication).
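To make the gotcha concrete, here is a small sketch contrasting the two meanings of *. The arrays are made up for illustration. The last line uses the @ operator, which is not part of the answer above; it assumes Python 3.5+ and a reasonably recent scipy, and always means matrix multiplication.

```python
import numpy as np
from scipy import sparse

A = np.arange(6.0).reshape(2, 3)        # [[0, 1, 2], [3, 4, 5]]
D = A + 1.0                             # another dense array, same shape

# ndarray * ndarray: elementwise product, shape stays (2, 3).
elementwise = A * D

# ndarray * sparse matrix: matrix product, shape becomes (2, 2).
dense_B = np.array([[1.0, 0.0],
                    [0.0, 2.0],
                    [3.0, 0.0]])
B = sparse.csc_matrix(dense_B)
product = A * B

# Assumption: Python 3.5+ and a recent scipy. @ is unambiguous.
matmul = A @ B

print(elementwise.shape)   # (2, 3)
print(product.shape)       # (2, 2)
print(np.allclose(product, A.dot(dense_B)))
print(np.allclose(matmul, product))
```

Note that this applies to the sparse matrix classes such as csc_matrix. The newer sparse array classes (csc_array and friends) treat * as elementwise, so @ is the safer habit if your code may see either kind.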

answered Oct 22 '22 09:10 by ali_m