multiple numpy dot products without a loop

Is it possible to compute several dot products without a loop? Say you have the following:

a = randn(100, 3, 3)
b = randn(100, 3, 3)

I want to get an array z of shape (100, 3, 3) such that for all i

z[i, ...] == dot(a[i, ...], b[i, ...])

In other words, z should satisfy:

for va, vb, vz in zip(a, b, z):
    assert (vz == dot(va, vb)).all()

The straightforward solution would be:

z = array([dot(va, vb) for va, vb in zip(a, b)])

which uses an implicit loop (list comprehension + array).

Is there a more efficient way to compute z?

asked Jun 06 '14 21:06 by user3716510



2 Answers

Try Einstein summation in numpy:

z = np.einsum('...ij,...jk->...ik', a, b)

It's elegant and does not require you to write a loop, as you requested. On my system it is about 4.8 times faster than the list comprehension:

%timeit z = array([dot(va, vb) for va, vb in zip(a, b)])
1000 loops, best of 3: 454 µs per loop

%timeit z = np.einsum('...ij,...jk->...ik', a, b)
10000 loops, best of 3: 94.6 µs per loop
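
As a quick sanity check on the subscript string, here is a minimal sketch that compares the einsum result with the loop from the question and shows that the leading `...` also covers more than one batch axis. The seed, the reshaped arrays `a4` and `b4`, and the variable names are assumptions made for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, only for reproducibility
a = rng.standard_normal((100, 3, 3))
b = rng.standard_normal((100, 3, 3))

# '...ij,...jk->...ik': sum over the shared index j, keep all leading axes.
z = np.einsum('...ij,...jk->...ik', a, b)

# Reference result from the per-matrix loop in the question.
z_loop = np.array([np.dot(va, vb) for va, vb in zip(a, b)])

# Compare with a tolerance, since the two paths may round differently.
matches_loop = np.allclose(z, z_loop)

# The same subscript string works with two leading batch axes.
a4 = a.reshape(10, 10, 3, 3)
b4 = b.reshape(10, 10, 3, 3)
z4 = np.einsum('...ij,...jk->...ik', a4, b4)
matches_reshaped = np.allclose(z4.reshape(100, 3, 3), z_loop)

print(matches_loop, matches_reshaped)
```

Because the ellipsis stands for whatever leading axes are present, the same call should work unchanged if the stack of matrices later gains extra dimensions.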
answered Oct 23 '22 23:10 by Oliver W.



np.einsum can be useful here. Try running this copy+pasteable code:

import numpy as np

a = np.random.randn(100, 3, 3)
b = np.random.randn(100, 3, 3)

z = np.einsum("ijk, ikl -> ijl", a, b)

z2 = np.array([ai.dot(bi) for ai, bi in zip(a, b)])

# Compare with a tolerance: einsum and dot may round floating-point sums differently.
assert np.allclose(z, z2)
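
To make the subscript string "ijk, ikl -> ijl" concrete, the following sketch spells it out as explicit loops: i runs over the stack of matrices, j and l index the output rows and columns, and k is the summed index. The small batch size of 4 and the names `z_manual` and `z_einsum` are assumptions chosen to keep the loops short; this only illustrates the notation and is not how einsum is implemented.

```python
import numpy as np

rng = np.random.default_rng(1)  # arbitrary seed, only for reproducibility
a = rng.standard_normal((4, 3, 3))
b = rng.standard_normal((4, 3, 3))

# z[i, j, l] = sum over k of a[i, j, k] * b[i, k, l]
z_manual = np.zeros((4, 3, 3))
for i in range(4):              # which matrix pair in the stack
    for j in range(3):          # output row
        for l in range(3):      # output column
            for k in range(3):  # summed (contracted) index
                z_manual[i, j, l] += a[i, j, k] * b[i, k, l]

z_einsum = np.einsum("ijk, ikl -> ijl", a, b)

# Both results should agree up to floating-point rounding.
same = np.allclose(z_manual, z_einsum)
print(same)
```

The index k appears in both inputs but not in the output, which is what tells einsum to sum over it.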

einsum runs in compiled code and is very fast, even compared to np.tensordot (which does not fit this batched case directly, but is often applicable to similar problems). Here are some timings:

In [8]: %timeit z = np.einsum("ijk, ikl -> ijl", a, b)
10000 loops, best of 3: 105 us per loop


In [9]: %timeit z2 = np.array([ai.dot(bi) for ai, bi in zip(a, b)])
1000 loops, best of 3: 1.06 ms per loop
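
The `%timeit` lines above need IPython. As a rough sketch of the same comparison for a plain Python script, the standard-library `timeit` module can be used instead. The helper names `with_loop` and `with_einsum` and the repeat counts are assumptions for illustration, and the measured numbers will differ from machine to machine.

```python
import timeit

import numpy as np

rng = np.random.default_rng(2)  # arbitrary seed, only for reproducibility
a = rng.standard_normal((100, 3, 3))
b = rng.standard_normal((100, 3, 3))


def with_loop():
    # One dot product per matrix pair, collected into a new array.
    return np.array([ai.dot(bi) for ai, bi in zip(a, b)])


def with_einsum():
    # Single call that handles the whole stack at once.
    return np.einsum("ijk, ikl -> ijl", a, b)


repeats, number = 3, 200  # best of 3 runs, 200 calls per run
t_loop = min(timeit.repeat(with_loop, repeat=repeats, number=number)) / number
t_einsum = min(timeit.repeat(with_einsum, repeat=repeats, number=number)) / number

print(f"loop:   {t_loop * 1e6:.1f} us per call")
print(f"einsum: {t_einsum * 1e6:.1f} us per call")
```

Taking the minimum over the repeats mirrors the "best of 3" figure that `%timeit` reported in the output above.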
answered Oct 23 '22 23:10 by eickenberg
