
Efficient reduction of multiple tensors in Python

I have four multidimensional tensors v[i,j,k], a[i,s,l], w[j,s,t,m], and x[k,t,n] in NumPy, and I am trying to compute the tensor z[l,m,n] given by:

z[l,m,n] = sum_{i,j,k,s,t} v[i,j,k] * a[i,s,l] * w[j,s,t,m] * x[k,t,n]

All the tensors are relatively small (say, fewer than 32k elements in total); however, I need to perform this computation many times, so I would like the function to have as little overhead as possible.

I tried to implement it using numpy.einsum like this:

z = np.einsum('ijk,isl,jstm,ktn', v, a, w, x)

but it was very slow. I also tried the following sequence of numpy.tensordot calls inside a double for loop over s and t:

z = np.zeros((a.shape[-1], w.shape[-1], x.shape[-1]))  # (l, m, n)
for s in range(a.shape[1]):
    for t in range(x.shape[1]):
        res = np.tensordot(v, a[:, s, :], (0, 0))       # contract i -> (j, k, l)
        res = np.tensordot(res, w[:, s, t, :], (0, 0))  # contract j -> (k, l, m)
        z += np.tensordot(res, x[:, t, :], (0, 0))      # contract k -> (l, m, n)

Both s and t are very small, so the explicit double loop is not too much of a problem by itself. This worked much better than einsum, but it is still not as fast as I would expect. I suspect the remaining overhead comes from all the work tensordot has to do internally before taking the actual product (e.g. permuting and reshaping the axes).

I was wondering if there is a more efficient way to implement this kind of operation in NumPy. I also wouldn't mind implementing this part in Cython, but I'm not sure what would be the right algorithm to use.
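As an aside, recent NumPy versions (1.12 and newer) can plan this kind of contraction themselves: einsum accepts an optimize keyword that picks a pairwise contraction order, and np.einsum_path reports the chosen order and the estimated FLOP savings. A minimal sketch, assuming NumPy >= 1.12:

import numpy as np

# Toy sizes; s and t kept small as in the problem
i, j, k, l, m, n = 10, 10, 10, 10, 10, 10
s, t = 3, 3
v = np.random.rand(i, j, k)
a = np.random.rand(i, s, l)
w = np.random.rand(j, s, t, m)
x = np.random.rand(k, t, n)

# Let einsum contract pairwise instead of looping over all eight indices at once
z = np.einsum('ijk,isl,jstm,ktn->lmn', v, a, w, x, optimize=True)

# Inspect the planned contraction order and the estimated cost
path, info = np.einsum_path('ijk,isl,jstm,ktn->lmn', v, a, w, x, optimize='optimal')
print(info)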

asked Mar 07 '16 by Alessandro

1 Answer

Using np.tensordot in parts, you can vectorize things like so -

# Perform "np.einsum('ijk,isl->jksl', v, a)"
p1 = np.tensordot(v,a,axes=([0],[0]))         # shape = jksl

# Perform "np.einsum('jksl,jstm->kltm', p1, w)"
p2 = np.tensordot(p1,w,axes=([0,2],[0,1]))    # shape = kltm

# Perform "np.einsum('kltm,ktn->lmn', p2, w)"
z = np.tensordot(p2,x,axes=([0,2],[0,1]))     # shape = lmn
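Each tensordot call here ultimately reduces to a single matrix multiplication over reshaped operands, which is what makes it fast. As an illustration only (an equivalent computation, not a claim about NumPy's internals), the first step can be spelled out as one np.dot:

# Equivalent of step 1: the contracted axis i is already leading in both
# operands, so flatten the trailing axes and perform one matrix product
i, j, k = v.shape
_, s, l = a.shape
p1_manual = np.dot(v.reshape(i, j*k).T, a.reshape(i, s*l)).reshape(j, k, s, l)
assert np.allclose(p1_manual, np.tensordot(v, a, axes=([0],[0])))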

Runtime test and output verification -

In [15]: def einsum_based(v, a, w, x):
    ...:     return np.einsum('ijk,isl,jstm,ktn', v, a, w, x) # (l,m,n)
    ...: 
    ...: def vectorized_tdot(v, a, w, x):
    ...:     p1 = np.tensordot(v,a,axes=([0],[0]))        # shape = jksl
    ...:     p2 = np.tensordot(p1,w,axes=([0,2],[0,1]))   # shape = kltm
    ...:     return np.tensordot(p2,x,axes=([0,2],[0,1])) # shape = lmn
    ...: 

Case #1 :

In [16]: # Input params
    ...: i,j,k,l,m,n = 10,10,10,10,10,10
    ...: s,t = 3,3 # As problem states : "both s and t are very small".
    ...: 
    ...: # Input arrays
    ...: v = np.random.rand(i,j,k)
    ...: a = np.random.rand(i,s,l)
    ...: w = np.random.rand(j,s,t,m)
    ...: x = np.random.rand(k,t,n)
    ...: 

In [17]: np.allclose(einsum_based(v, a, w, x),vectorized_tdot(v, a, w, x))
Out[17]: True

In [18]: %timeit einsum_based(v,a,w,x)
10 loops, best of 3: 129 ms per loop

In [19]: %timeit vectorized_tdot(v,a,w,x)
1000 loops, best of 3: 397 µs per loop

Case #2 (bigger data sizes) :

In [20]: # Input params
    ...: i,j,k,l,m,n = 15,15,15,15,15,15
    ...: s,t = 3,3 # As problem states : "both s and t are very small".
    ...: 
    ...: # Input arrays
    ...: v = np.random.rand(i,j,k)
    ...: a = np.random.rand(i,s,l)
    ...: w = np.random.rand(j,s,t,m)
    ...: x = np.random.rand(k,t,n)
    ...: 

In [21]: np.allclose(einsum_based(v, a, w, x),vectorized_tdot(v, a, w, x))
Out[21]: True

In [22]: %timeit einsum_based(v,a,w,x)
1 loops, best of 3: 1.35 s per loop

In [23]: %timeit vectorized_tdot(v,a,w,x)
1000 loops, best of 3: 1.52 ms per loop
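
The speedup (roughly 300x in Case #1 and close to 900x in Case #2) presumably comes from each tensordot step being a single BLAS matrix multiply over two operands, whereas the four-operand einsum call, run without an optimized path, iterates over all eight indices at once.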
answered Oct 21 '22 by Divakar