
Alternatives to numpy einsum

When I calculate the third-order moments of a matrix X with N rows and n columns, I usually use einsum:

M3 = sp.einsum('ij,ik,il->jkl',X,X,X) /N
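Element-wise, the einsum call above computes the following. The shared subscript i (the row index) is summed over, and j, k, l index the result:

```latex
M3_{jkl} = \frac{1}{N} \sum_{i=1}^{N} X_{ij}\, X_{ik}\, X_{il}, \qquad j, k, l \in \{1, \dots, n\}
```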

This usually works fine, but now I am working with bigger sizes, namely n = 120 and N = 100000, and einsum raises the following error:

ValueError: iterator is too large

The alternative of writing three nested loops is infeasible, so I am wondering whether there is any other option.

Ulderique Demoitre asked Jul 15 '16 13:07


1 Answer

Note that calculating this needs at least ~n³ × N = 173 billion operations (not exploiting symmetry), so it will be slow without special hardware such as a GPU. On a modern computer with a ~3 GHz CPU, the whole computation is expected to take about 60 seconds, assuming no SIMD or parallel speedup.
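The estimate above is plain arithmetic. A quick sketch of it, assuming (as the estimate does) one operation per (i, j, k, l) index tuple, one operation per clock cycle, and no SIMD or multithreading:

```python
n = 120          # columns of X
N = 100_000      # rows of X
clock_hz = 3e9   # assumed ~3 GHz CPU, one operation per cycle

ops = n ** 3 * N             # one operation per (i, j, k, l) index tuple
seconds = ops / clock_hz     # single-core, no-SIMD estimate

print(f"{ops:,} operations")          # 172,800,000,000 operations
print(f"~{seconds:.1f} s at 3 GHz")   # ~57.6 s at 3 GHz
```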

For testing, let's start with N = 1000. We will use this to check correctness and performance:

#!/usr/bin/env python3

import numpy
import time


n = 120
N = 1000
X = numpy.random.random((N, n))

start_time = time.time()

M3 = numpy.einsum('ij,ik,il->jkl', X, X, X)

end_time = time.time()

# X is random, so check invariants rather than fixed numbers:
# M3 is symmetric under any permutation of (j, k, l) ...
print('check:', M3[2,4,6], '=', M3[4,2,6], '=', M3[6,4,2], '?')
# ... and summing over j, k, l gives the sum over rows of (row sum)**3
print('check:', numpy.sum(M3), '=', numpy.sum(X.sum(axis=1) ** 3), '?')
print('total time =', end_time - start_time)

This takes about 8 seconds. This is the baseline.

Let's start with the three nested loops as the alternative:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    for k in range(n):
        for l in range(n):
            M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l])
# ~27 seconds

This takes roughly half a minute: no good! One reason is that this is actually four nested loops, because numpy.sum is itself a loop over the N rows.
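To make the hidden fourth loop visible, here is the same computation with numpy.sum written out as an explicit loop over the rows. Tiny sizes (n = 4, N = 10) are an arbitrary choice so the pure-Python loops finish instantly:

```python
import numpy

n, N = 4, 10                       # tiny sizes so pure-Python loops are fast
X = numpy.random.random((N, n))

M3 = numpy.zeros((n, n, n))
for j in range(n):
    for k in range(n):
        for l in range(n):
            total = 0.0
            for i in range(N):     # the loop that numpy.sum performs internally
                total += X[i, j] * X[i, k] * X[i, l]
            M3[j, k, l] = total

# Same result as the einsum baseline
print(numpy.allclose(M3, numpy.einsum('ij,ik,il->jkl', X, X, X)))  # True
```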

We note that the sum of products can be turned into a dot product to remove this fourth loop:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    for k in range(n):
        for l in range(n):
            M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l]
# 14 seconds
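A quick check that the dot product gives the same number as numpy.sum for a single index triple (the random data and the triple (2, 4, 6) are arbitrary choices). Note that `*` and `@` have the same precedence in Python and group left to right, so the expression is `(X[:, j] * X[:, k]) @ X[:, l]`:

```python
import numpy

N, n = 1000, 120
X = numpy.random.random((N, n))
j, k, l = 2, 4, 6                  # an arbitrary index triple

via_sum = numpy.sum(X[:, j] * X[:, k] * X[:, l])
via_dot = X[:, j] * X[:, k] @ X[:, l]   # (X[:, j] * X[:, k]) @ X[:, l]

print(numpy.isclose(via_sum, via_dot))  # True
```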

Much better, but still slow. Note that the dot product can be turned into a matrix multiplication to remove one more loop:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    for k in range(n):
        M3[j,k] = X[:,j] * X[:,k] @ X
# ~0.5 seconds
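Why this removes the l loop: a vector of shape (N,) times a matrix of shape (N, n) yields n dot products at once, one per column of X. A small sketch comparing it against the stacked per-column dot products from the previous version (random data and the pair (2, 4) are arbitrary):

```python
import numpy

N, n = 1000, 120
X = numpy.random.random((N, n))
j, k = 2, 4

row = X[:, j] * X[:, k] @ X        # shape (N,) @ (N, n) -> (n,)
stacked = numpy.array([X[:, j] * X[:, k] @ X[:, l] for l in range(n)])

print(row.shape)                    # (120,)
print(numpy.allclose(row, stacked)) # True
```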

Huh? This is now much faster than einsum! We can also check that the answer is indeed correct.
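One way to do that check is to compare the k-loop version against the einsum result with numpy.allclose. This sketch uses a smaller n = 30 (an assumption, chosen only so the einsum reference runs quickly):

```python
import numpy

n, N = 30, 1000                    # smaller n keeps the einsum reference fast
X = numpy.random.random((N, n))

# The k-loop version from above
M3_loop = numpy.zeros((n, n, n))
for j in range(n):
    for k in range(n):
        M3_loop[j, k] = X[:, j] * X[:, k] @ X

# Reference result from einsum
M3_ref = numpy.einsum('ij,ik,il->jkl', X, X, X)

print(numpy.allclose(M3_loop, M3_ref))   # True
```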

Can we go further? Yes! We could eliminate the k loop by:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    Y = numpy.repeat(X[:,j], n).reshape((N, n))
    M3[j] = (Y * X).T @ X
# ~0.3 seconds
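What the numpy.repeat/reshape line builds: an (N, n) matrix whose row i is X[i, j] repeated n times, so `Y * X` scales row i of X by X[i, j]. A tiny illustration, followed by a check of one slice M3[j] against einsum (the sizes and j = 1 are arbitrary):

```python
import numpy

# A 3-element column standing in for X[:, j], repeated across 2 columns
col = numpy.array([1.0, 2.0, 3.0])
print(numpy.repeat(col, 2).reshape((3, 2)))
# [[1. 1.]
#  [2. 2.]
#  [3. 3.]]

# The slice M3[j] from the k-free version agrees with einsum
X = numpy.random.random((50, 4))
N, n = X.shape
j = 1
Y = numpy.repeat(X[:, j], n).reshape((N, n))
slice_j = (Y * X).T @ X            # (n, N) @ (N, n) -> (n, n)
print(numpy.allclose(slice_j, numpy.einsum('ij,ik,il->jkl', X, X, X)[j]))  # True
```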

We can also use broadcasting (i.e. a * [b, c] == [a*b, a*c], applied to each row of X) to avoid the numpy.repeat (thanks @Divakar):

M3 = numpy.zeros((n, n, n))
for j in range(n):
    Y = X[:,j].reshape((N, 1))
    ## or, equivalently: 
    # Y = X[:, numpy.newaxis, j]
    M3[j] = (Y * X).T @ X
# ~0.16 seconds
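A small sketch confirming that the (N, 1) column broadcasts to the same product as the explicit numpy.repeat copy, and that both ways of building Y give shape (N, 1) (random data and j = 3 are arbitrary):

```python
import numpy

N, n = 1000, 120
X = numpy.random.random((N, n))
j = 3

Y_repeat = numpy.repeat(X[:, j], n).reshape((N, n))   # explicit (N, n) copy
Y_col = X[:, j].reshape((N, 1))                       # (N, 1) column
Y_newaxis = X[:, numpy.newaxis, j]                    # also (N, 1)

print(Y_col.shape, Y_newaxis.shape)                   # (1000, 1) (1000, 1)
# Broadcasting stretches the single column across all n columns of X
print(numpy.array_equal(Y_repeat * X, Y_col * X))     # True
```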

If we scale this to N = 100000, the program is expected to take about 16 seconds, which already beats the ~60-second single-core estimate above (presumably thanks to SIMD and multithreading in the underlying BLAS). So eliminating the j loop as well may not help much, and it may make the code much harder to understand. We can accept this as the final solution.
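For completeness, here is a sketch of what eliminating the j loop could look like. It is not part of the solution above: the function name `third_moment` and the `chunk_rows` parameter are hypothetical, and the default block size is an assumption that keeps the temporary (rows, n*n) array near 115 MB for n = 120. It divides by N, as the question's M3 does. Whether it is faster than the j-loop version is not established here.

```python
import numpy


def third_moment(X, chunk_rows=1000):
    """Hypothetical helper: M3[j, k, l] = sum_i X[i,j] * X[i,k] * X[i,l] / N.

    Forms all pairwise column products for a block of rows, then uses one
    matrix product per block. Rows are processed in blocks so the
    (rows, n*n) temporary stays bounded in memory.
    """
    N, n = X.shape
    acc = numpy.zeros((n * n, n))
    for start in range(0, N, chunk_rows):
        block = X[start:start + chunk_rows]                       # (b, n)
        # pairs[i, j*n + k] = block[i, j] * block[i, k]
        pairs = (block[:, :, numpy.newaxis] * block[:, numpy.newaxis, :])
        pairs = pairs.reshape(len(block), n * n)
        acc += pairs.T @ block                                    # (n*n, n)
    return acc.reshape((n, n, n)) / N


X = numpy.random.random((2003, 20))   # row count deliberately not a multiple of the block size
M3 = third_moment(X, chunk_rows=500)
print(numpy.allclose(M3, numpy.einsum('ij,ik,il->jkl', X, X, X) / len(X)))  # True
```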

Note: the @ operator requires Python 3.5 or later. If you are using Python 2, write a.dot(b) instead of a @ b.

kennytm answered Dec 15 '22 21:12
