
Fastest way to use Numpy - multi-dimensional sums and products

I have these variables with the following dimensions:

A   - (3,)
B   - (4,)
X_r - (3,K,N,nS)
X_u - (4,K,N,nS)
k - (K,)

I want to compute (A.dot(X_r[:,:,n,s])*B.dot(X_u[:,:,n,s])).dot(k) for every possible n and s. This is how I am doing it now:

np.array([[(A.dot(X_r[:,:,n,s])*B.dot(X_u[:,:,n,s])).dot(k) for n in range(N)] for s in range(nS)]) # shape nS x N (use xrange on Python 2)

But this is very slow, and I was wondering whether there is a better way of doing it.

There is also another computation that I am doing, and I am sure this one can be optimized:

np.sum(np.array([(X_r[:,:,n,s]*B.dot(X_u[:,:,n,s])).dot(k) for n in range(N)]),axis=0) # use xrange on Python 2

Here I am creating a NumPy array just to sum it along one axis and then discard it. If this were a 1-D list, I would use reduce to optimize it. What should I use for NumPy arrays?

João Abrantes asked May 17 '15 09:05


2 Answers

Using a few np.einsum calls:

# Calculation of A.dot(X_r[:,:,n,s])
p1 = np.einsum('i,ijkl->jkl',A,X_r)

# Calculation of B.dot(X_u[:,:,n,s])
p2 = np.einsum('i,ijkl->jkl',B,X_u)

# Include .dot(k) part to get the final output
out = np.einsum('ijk,i->kj',p1*p2,k)
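As a sanity check, the three einsum calls above can be compared against the loop from the question on random data. This is only a sketch: the sizes K, N and nS and the random inputs are assumptions made for illustration, since the question does not state them.

```python
import numpy as np

# Hypothetical sizes, chosen only for this check (not given in the question).
K, N, nS = 5, 6, 7
rng = np.random.default_rng(0)
A = rng.standard_normal(3)
B = rng.standard_normal(4)
X_r = rng.standard_normal((3, K, N, nS))
X_u = rng.standard_normal((4, K, N, nS))
k = rng.standard_normal(K)

# Loop version from the question, written with Python 3's range.
ref = np.array([[(A.dot(X_r[:, :, n, s]) * B.dot(X_u[:, :, n, s])).dot(k)
                 for n in range(N)] for s in range(nS)])

# einsum version: contract the leading axis first, then the K axis.
p1 = np.einsum('i,ijkl->jkl', A, X_r)      # shape (K, N, nS)
p2 = np.einsum('i,ijkl->jkl', B, X_u)      # shape (K, N, nS)
out = np.einsum('ijk,i->kj', p1 * p2, k)   # shape (nS, N)

print(out.shape, np.allclose(out, ref))
```

The output subscripts 'kj' swap the N and nS axes, so the result has the same nS x N layout as the nested list comprehension.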

About the second example, this solves it:

p1 = np.einsum('i,ijkl->jkl',B,X_u) # output shape (K,N,nS)
sol = np.einsum('ijkl,j->il',X_r*p1[None,:,:,:],k) # output shape (3,nS), one column per s
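The second einsum solution returns one column per value of s, whereas the expression in the question is written for a single s. A minimal sketch of a check, assuming hypothetical sizes and random inputs, evaluates the question's expression for every s and stacks the results:

```python
import numpy as np

# Hypothetical sizes, chosen only for this check (not given in the question).
K, N, nS = 5, 6, 7
rng = np.random.default_rng(1)
B = rng.standard_normal(4)
X_r = rng.standard_normal((3, K, N, nS))
X_u = rng.standard_normal((4, K, N, nS))
k = rng.standard_normal(K)

def loop_for_s(s):
    """Expression from the question for one fixed s; returns shape (3,)."""
    return np.sum(np.array([(X_r[:, :, n, s] * B.dot(X_u[:, :, n, s])).dot(k)
                            for n in range(N)]), axis=0)

# Stack the per-s results as columns to match the einsum layout (3, nS).
ref = np.stack([loop_for_s(s) for s in range(nS)], axis=1)

p1 = np.einsum('i,ijkl->jkl', B, X_u)                       # shape (K, N, nS)
sol = np.einsum('ijkl,j->il', X_r * p1[None, :, :, :], k)   # shape (3, nS)

print(sol.shape, np.allclose(sol, ref))
```

The einsum sums over both the K axis (subscript j) and the N axis (subscript k) in one call, so no intermediate N x 3 array is built just to be reduced.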
Divakar answered Oct 27 '22 10:10


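You can also use dot on higher-dimensional arrays, but it contracts the last axis of its first argument with the second-to-last axis of its second argument (or the last axis if the second argument is 1-D), so the axes to be summed must sit at the end. When we reorder the axes of your arrays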

X_r_t = X_r.transpose(2,3,0,1)
X_u_t = X_u.transpose(2,3,0,1)

we obtain for your first expression

res1_imp = (A.dot(X_r_t)*B.dot(X_u_t)).dot(k).T # shape nS x N

and for the second expression

res2_imp = np.sum((X_r_t * B.dot(X_u_t)[:,:,None,:]).dot(k),axis=0)[-1] # [-1] picks s = nS-1; without it the shape is nS x 3
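The shapes involved in the transpose-and-dot approach are easy to lose track of, so here is a small sketch that annotates them and compares both results with the loops from the question. The sizes and random inputs are assumptions for illustration. Note that the trailing [-1] in the second expression selects the last value of s only.

```python
import numpy as np

# Hypothetical sizes, for illustration only (not given in the question).
K, N, nS = 5, 6, 7
rng = np.random.default_rng(2)
A = rng.standard_normal(3)
B = rng.standard_normal(4)
X_r = rng.standard_normal((3, K, N, nS))
X_u = rng.standard_normal((4, K, N, nS))
k = rng.standard_normal(K)

# Move n and s to the front so the contracted axes come last.
X_r_t = X_r.transpose(2, 3, 0, 1)   # shape (N, nS, 3, K)
X_u_t = X_u.transpose(2, 3, 0, 1)   # shape (N, nS, 4, K)

# A 1-D array dotted with an N-D array contracts the second-to-last axis.
a_part = A.dot(X_r_t)               # shape (N, nS, K)
b_part = B.dot(X_u_t)               # shape (N, nS, K)

res1 = (a_part * b_part).dot(k).T   # shape (nS, N)
res2_all = np.sum((X_r_t * b_part[:, :, None, :]).dot(k), axis=0)  # shape (nS, 3)
res2_last = res2_all[-1]            # shape (3,), the result for s = nS - 1

# Loop references from the question.
ref1 = np.array([[(A.dot(X_r[:, :, n, s]) * B.dot(X_u[:, :, n, s])).dot(k)
                  for n in range(N)] for s in range(nS)])
s = nS - 1
ref2_last = np.sum(np.array([(X_r[:, :, n, s] * B.dot(X_u[:, :, n, s])).dot(k)
                             for n in range(N)]), axis=0)

print(np.allclose(res1, ref1), np.allclose(res2_last, ref2_last))
```

Dropping the [-1] gives the result for every s at once, with shape nS x 3, which is the transpose of the (3, nS) array produced by the einsum answer.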

Timings

On my computer, Divakar's solution gives 10000 loops, best of 3: 21.7 µs per loop

My solution gives 10000 loops, best of 3: 101 µs per loop

Edit

The timing for my solution above included the computation of both expressions. When I time only the first expression (as for Divakar's solution), I obtain 10000 loops, best of 3: 41 µs per loop, which is still slower but closer to his timing.
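To repeat this comparison on another machine, something along the following lines could be used. It is only a sketch: the array sizes behind the timings above are not stated, so the sizes here are assumptions, the function names are made up for this example, and the absolute numbers will differ from those quoted.

```python
import timeit
import numpy as np

# Hypothetical sizes; the sizes used for the timings above are not stated.
K, N, nS = 5, 6, 7
rng = np.random.default_rng(3)
A = rng.standard_normal(3)
B = rng.standard_normal(4)
X_r = rng.standard_normal((3, K, N, nS))
X_u = rng.standard_normal((4, K, N, nS))
k = rng.standard_normal(K)

def einsum_version():
    """First expression via three einsum calls."""
    p1 = np.einsum('i,ijkl->jkl', A, X_r)
    p2 = np.einsum('i,ijkl->jkl', B, X_u)
    return np.einsum('ijk,i->kj', p1 * p2, k)

def dot_version():
    """First expression via transpose and dot."""
    X_r_t = X_r.transpose(2, 3, 0, 1)
    X_u_t = X_u.transpose(2, 3, 0, 1)
    return (A.dot(X_r_t) * B.dot(X_u_t)).dot(k).T

loops = 1000
for name, fn in [("einsum", einsum_version), ("dot", dot_version)]:
    # Best of 3 repeats, reported per call in microseconds.
    best = min(timeit.repeat(fn, number=loops, repeat=3)) / loops
    print(f"{name}: {best * 1e6:.1f} µs per loop")
```

Both functions compute only the first expression, so the two timings are directly comparable. Which one is faster may depend on the array sizes and the NumPy version.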

plonser answered Oct 27 '22 11:10