What is the best way to compute the trace of a matrix product in numpy?

If I have numpy arrays A and B, then I can compute the trace of their matrix product with:

tr = numpy.trace(A.dot(B))

However, the matrix multiplication A.dot(B) unnecessarily computes all of the off-diagonal entries in the matrix product, when only the diagonal elements are used in the trace. Instead, I could do something like:

tr = 0.0
for i in range(n):  # n = A.shape[0], the dimension of the square product A.dot(B)
    tr += A[i, :].dot(B[:, i])

but this performs the loop in Python code and isn't as obvious as numpy.trace.

Is there a better way to compute the trace of a matrix product of numpy arrays? What is the fastest or most idiomatic way to do this?
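The reason the off-diagonal entries are never needed is the identity tr(AB) = Σᵢⱼ AᵢⱼBⱼᵢ: the trace is just the sum of elementwise products of A with the transpose of B. A quick sanity check (illustrative sketch with made-up sizes, not part of the original question):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((4, 3))
B = rng.random((3, 4))

# Full matrix product, then trace (wasteful: computes off-diagonal entries too)
full = np.trace(A.dot(B))

# Elementwise products only: tr(AB) == sum_{i,j} A[i, j] * B[j, i]
elementwise = np.sum(A * B.T)

assert np.allclose(full, elementwise)
```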

amcnabb asked Sep 17 '13.


1 Answer

You can improve on @Bill's solution by reducing intermediate storage to the diagonal elements only:

import numpy as np
from numpy.core.umath_tests import inner1d  # private module; removed in NumPy 2.0

m, n = 1000, 500

a = np.random.rand(m, n)
b = np.random.rand(n, m)

# They should all give the same result
print(np.trace(a.dot(b)))
print(np.sum(a * b.T))
print(np.sum(inner1d(a, b.T)))

%timeit np.trace(a.dot(b))
10 loops, best of 3: 34.7 ms per loop

%timeit np.sum(a*b.T)
100 loops, best of 3: 4.85 ms per loop

%timeit np.sum(inner1d(a, b.T))
1000 loops, best of 3: 1.83 ms per loop

Another option is to use np.einsum and have no explicit intermediate storage at all:

# Will print the same as the others:
print(np.einsum('ij,ji->', a, b))

On my system it runs slightly slower than using inner1d, but this may not hold on all systems; see this related question:

%timeit np.einsum('ij,ji->', a, b)
100 loops, best of 3: 1.91 ms per loop
Jaime answered Sep 19 '22.