
Simplifying double einsum

I'm trying to use numpy.einsum to simplify a loop I have in my code.

Currently, my code looks something like this:

import numpy as np

k = 100
m = 50
n = 10
A = np.arange(k*m*n).reshape(k, m, n)
B = np.arange(m*m).reshape(m, m)

# Accumulate A[ind].T . B . A[ind] over the first axis of A
T = np.zeros((n, n))
for ind in range(k):
    T += np.dot(A[ind,:,:].T, np.dot(B, A[ind,:,:]))

I have managed to replace the loop with two nested numpy.einsum calls:

Tp = np.einsum('nij,njk->ik', np.einsum('nij,jk->nik', A.transpose(0,2,1), B), A)
print(np.allclose(T, Tp))

Is it possible to use a single numpy.einsum instead of two?

Javier C. asked Mar 14 '23 08:03



1 Answer

On my PC, your two-step einsum version gives this timing:

np.einsum('nij,njk->ik', np.einsum('nij,jk->nik', A.transpose(0,2,1), B), A)
# 100 loops, best of 3: 4.55 ms per loop

You can do the same thing in a single np.einsum call with:

T2 = np.einsum('nij, il, kln ->jk', A, B, A.T)
#  10 loops, best of 3: 51.9 ms per loop
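
To make the index labels easier to follow, here is a minimal sketch of the same contraction with A passed twice instead of A and A.T, checked against the loop from the question. The small sizes and the random float data are assumptions made only for this check; they are not the arrays from the question.

```python
import numpy as np

k, m, n = 6, 5, 4  # small sizes, chosen only for this check
rng = np.random.default_rng(0)
A = rng.standard_normal((k, m, n))
B = rng.standard_normal((m, m))

# Reference: the explicit loop from the question.
T = np.zeros((n, n))
for ind in range(k):
    T += A[ind].T @ B @ A[ind]

# Single einsum call. Labels: 'n' is the loop index (first axis of A),
# 'i' and 'l' are the two axes of B, 'j' and 'k' are the output axes.
# T2[j, k] = sum over n, i, l of A[n, i, j] * B[i, l] * A[n, l, k]
T2 = np.einsum('nij,il,nlk->jk', A, B, A)

print(np.allclose(T, T2))
```

Writing 'nlk' for the second A is equivalent to writing 'kln' for A.T, since transposing a 3-D array reverses the order of its axes.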

or using a double np.tensordot():

T3 = np.tensordot(A, np.tensordot(A, B, axes=(1, 1)), axes=((0, 1), (0, 2)))
# 100 loops, best of 3: 2.73 ms per loop
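
The nested tensordot() call is easier to read when split into its two steps. This is only a sketch: the intermediate name C, the reference T_ref, and the small random inputs are assumptions introduced for illustration.

```python
import numpy as np

k, m, n = 6, 5, 4  # small sizes, chosen only for this check
rng = np.random.default_rng(0)
A = rng.standard_normal((k, m, n))
B = rng.standard_normal((m, m))

# Step 1: contract axis 1 of A with axis 1 of B.
# C[p, j, i] = sum over l of A[p, l, j] * B[i, l], so C has shape (k, n, m).
C = np.tensordot(A, B, axes=(1, 1))

# Step 2: contract axes (0, 1) of A with axes (0, 2) of C.
# T3[a, b] = sum over p, i of A[p, i, a] * C[p, b, i], shape (n, n).
T3 = np.tensordot(A, C, axes=((0, 1), (0, 2)))

# Reference: the loop from the question, written as a sum.
T_ref = sum(A[p].T @ B @ A[p] for p in range(k))

print(C.shape, T3.shape, np.allclose(T3, T_ref))
```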

My conclusion is that you get better performance by doing this operation in two steps. This is probably due to the larger strides involved when the whole operation is performed at once, possibly causing more cache misses.
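
One thing that may be worth trying, which was not measured above: np.einsum accepts an optimize argument (NumPy 1.12 and later) that lets it break a multi-operand expression into pairwise contractions instead of evaluating everything in one pass. Whether it helps here is an assumption to test, not a result. The function names below are hypothetical, and timings will vary by machine and NumPy version.

```python
import timeit
import numpy as np

k, m, n = 100, 50, 10
rng = np.random.default_rng(0)
A = rng.standard_normal((k, m, n))
B = rng.standard_normal((m, m))

def single_call():
    # One einsum over all three operands, default settings.
    return np.einsum('nij,il,nlk->jk', A, B, A)

def single_call_optimized():
    # Same expression; optimize=True lets einsum choose a contraction order.
    return np.einsum('nij,il,nlk->jk', A, B, A, optimize=True)

def two_step():
    # The two nested einsum calls from the question.
    inner = np.einsum('nij,jk->nik', A.transpose(0, 2, 1), B)
    return np.einsum('nij,njk->ik', inner, A)

for f in (single_call, single_call_optimized, two_step):
    t = timeit.timeit(f, number=5) / 5
    print(f"{f.__name__}: {t * 1e3:.2f} ms per call")
```

np.einsum_path() with the same subscripts and operands reports which contraction order would be used, which can help explain any difference you see.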

Saullo G. P. Castro answered Mar 24 '23 18:03
