Numpy element-wise dot product

Tags: python, numpy

Is there an elegant NumPy way to apply the dot product element-wise? Or, how can the code below be translated into a nicer version?

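import numpy as np

m0 = np.random.rand(5, 3, 2, 2)  # example input, shape (5, 3, 2, 2)
m1 = np.random.rand(5, 2, 2)     # example input, shape (5, 2, 2)
r = np.empty((5, 3, 2, 2))
for i in range(5):
    for j in range(3):
        r[i, j] = np.dot(m0[i, j], m1[i])  # one 2x2 matrix product per (i, j)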

Thanks in advance!

asked Jan 03 '17 12:01 by Philipp H.




1 Answer

Approach #1

Use np.einsum:

np.einsum('ijkl,ilm->ijkm',m0,m1)

Steps involved:

  • Keep the first axis (i) of both inputs aligned.

  • Sum-reduce the last axis of m0 against the second axis of m1 (both labeled l).

  • Let the remaining axes, j and k from m0 and m from m1, expand in an outer-product fashion through elementwise multiplication.
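To make the subscript string concrete, here is a sketch that spells out 'ijkl,ilm->ijkm' as explicit loops (i shared and kept, l summed away, j, k and m kept) and checks it against np.einsum on small random inputs. The helper name einsum_by_hand is hypothetical and only meant for illustration; it is far too slow for real use.

```python
import numpy as np

def einsum_by_hand(m0, m1):
    """Hypothetical helper: explicit-loop reading of 'ijkl,ilm->ijkm'."""
    I, J, K, L = m0.shape
    M = m1.shape[-1]
    out = np.zeros((I, J, K, M))
    for i in range(I):              # i: shared leading axis, kept aligned
        for j in range(J):          # j: only in m0, kept in the output
            for k in range(K):      # k: only in m0, kept in the output
                for m in range(M):  # m: only in m1, kept in the output
                    # l appears in both inputs but not in the output, so it is summed
                    out[i, j, k, m] = sum(m0[i, j, k, l] * m1[i, l, m] for l in range(L))
    return out

rng = np.random.default_rng(0)
m0 = rng.random((5, 3, 2, 2))
m1 = rng.random((5, 2, 2))
print(np.allclose(einsum_by_hand(m0, m1), np.einsum('ijkl,ilm->ijkm', m0, m1)))  # True
```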


Approach #2

If performance matters and the sum-reduction axis is short, you are better off with a single loop that performs matrix multiplication through np.tensordot, like so:

s0,s1,s2,s3 = m0.shape
s4 = m1.shape[-1]
r = np.empty((s0,s1,s2,s4))
for i in range(s0):
    r[i] = np.tensordot(m0[i],m1[i],axes=([2],[0]))
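Inside the loop, each m0[i] is 3D with shape (s1, s2, s3) and each m1[i] is 2D with shape (s3, s4), and axes=([2],[0]) contracts the last axis of the first against the first axis of the second. A quick check of the per-slice shape and of its agreement with the original inner loop, using small random arrays a and b as stand-ins for one slice (the shapes are arbitrary assumptions):

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.random((3, 2, 4))   # stands in for one slice m0[i]: (s1, s2, s3)
b = rng.random((4, 5))      # stands in for one slice m1[i]: (s3, s4)

t = np.tensordot(a, b, axes=([2], [0]))   # contract axis 2 of a with axis 0 of b
print(t.shape)  # (3, 2, 5)

# Same result as dotting each 2D block a[j] with b, as in the original inner loop
per_j = np.stack([np.dot(a[j], b) for j in range(a.shape[0])])
print(np.allclose(t, per_j))  # True
```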

Approach #3

np.dot is efficient on 2D inputs, so a further speedup is possible by reshaping m0 to 3D, which turns each loop iteration into a single 2D matrix product. This version is a bit longer, but it should be the most performant:

s0,s1,s2,s3 = m0.shape
s4 = m1.shape[-1]
m0_3d = m0.reshape(s0, s1*s2, s3)  # 3D view of m0 (copied only if needed); m0 itself is not modified
r = np.empty((s0,s1*s2,s4))
for i in range(s0):
    r[i] = m0_3d[i].dot(m1[i])     # one (s1*s2, s3) x (s3, s4) product per i
r = r.reshape(s0,s1,s2,s4)         # restore the 4D output shape
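The reshape works because stacking the s1 blocks of shape (s2, s3) on top of each other yields one (s1*s2, s3) matrix, and a matrix product computes each output row independently, so it does not matter which block a row came from. A small sketch of that equivalence for a single slice, with arbitrary example shapes:

```python
import numpy as np

rng = np.random.default_rng(2)
s1, s2, s3, s4 = 3, 2, 4, 5
a = rng.random((s1, s2, s3))   # one slice m0[i]
b = rng.random((s3, s4))       # the matching slice m1[i]

# One 2D product over all s1*s2 rows, then restore the (s1, s2) split
stacked = a.reshape(s1 * s2, s3).dot(b).reshape(s1, s2, s4)

# Reference: one small product per block, as in the original loop over j
blockwise = np.stack([a[j].dot(b) for j in range(s1)])
print(np.allclose(stacked, blockwise))  # True
```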

Runtime test

Function definitions:

import numpy as np

def original_app(m0, m1):
    s0,s1,s2,s3 = m0.shape
    s4 = m1.shape[-1]
    r = np.empty((s0,s1,s2,s4))
    for i in range(s0):
        for j in range(s1):
            r[i, j] = np.dot(m0[i, j], m1[i])
    return r

def einsum_app(m0, m1):
    return np.einsum('ijkl,ilm->ijkm',m0,m1)

def tensordot_app(m0, m1):
    s0,s1,s2,s3 = m0.shape
    s4 = m1.shape[-1]
    r = np.empty((s0,s1,s2,s4))
    for i in range(s0):
        r[i] = np.tensordot(m0[i],m1[i],axes=([2],[0]))
    return r        

def dot_app(m0, m1):
    s0,s1,s2,s3 = m0.shape
    s4 = m1.shape[-1]
    m0_3d = m0.reshape(s0, s1*s2, s3)  # 3D view of m0 (copied only if needed); m0 is not modified
    r = np.empty((s0,s1*s2,s4))
    for i in range(s0):
        r[i] = m0_3d[i].dot(m1[i])
    return r.reshape(s0,s1,s2,s4)      # back to 4D
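Not part of the original answer or its timings: np.matmul (the @ operator) treats all but the last two axes as batch dimensions and broadcasts them, so inserting a length-1 axis into m1 pairs each m1[i] with every m0[i, j] without any Python loop. A sketch under that assumption; matmul_app follows the naming of the functions above but is an addition, and its speed relative to the other approaches has not been measured here.

```python
import numpy as np

def matmul_app(m0, m1):
    # m1[:, None] has shape (s0, 1, s3, s4); its batch axes (s0, 1) broadcast
    # against m0's (s0, s1), and each batch entry is an (s2, s3) @ (s3, s4) product.
    return m0 @ m1[:, None]

rng = np.random.default_rng(3)
m0 = rng.random((5, 3, 2, 2))
m1 = rng.random((5, 2, 2))
out = matmul_app(m0, m1)
print(out.shape)  # (5, 3, 2, 2)
print(np.allclose(out, np.einsum('ijkl,ilm->ijkm', m0, m1)))  # True
```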

Timings and verification:

In [291]: # Inputs
     ...: m0 = np.random.rand(50,30,20,20)
     ...: m1 = np.random.rand(50,20,20)
     ...: 

In [292]: out1 = original_app(m0, m1)
     ...: out2 = einsum_app(m0, m1)
     ...: out3 = tensordot_app(m0, m1)
     ...: out4 = dot_app(m0, m1)
     ...: 
     ...: print(np.allclose(out1, out2))
     ...: print(np.allclose(out1, out3))
     ...: print(np.allclose(out1, out4))
     ...: 
True
True
True

In [293]: %timeit original_app(m0, m1)
     ...: %timeit einsum_app(m0, m1)
     ...: %timeit tensordot_app(m0, m1)
     ...: %timeit dot_app(m0, m1)
     ...: 
100 loops, best of 3: 10.3 ms per loop
10 loops, best of 3: 31.3 ms per loop
100 loops, best of 3: 5.12 ms per loop
100 loops, best of 3: 4.06 ms per loop
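The session above relies on IPython's %timeit. Outside IPython, a similar check can be run with the standard-library timeit module. The sketch below re-defines two of the approaches so that it runs on its own and uses the same input shapes; it prints its own timings, which will vary with the machine and NumPy build and are not meant to reproduce the figures above.

```python
import timeit
import numpy as np

def original_app(m0, m1):
    # Original double loop: one 2D dot per (i, j) pair
    s0, s1, s2, _ = m0.shape
    r = np.empty((s0, s1, s2, m1.shape[-1]))
    for i in range(s0):
        for j in range(s1):
            r[i, j] = np.dot(m0[i, j], m1[i])
    return r

def einsum_app(m0, m1):
    return np.einsum('ijkl,ilm->ijkm', m0, m1)

m0 = np.random.rand(50, 30, 20, 20)
m1 = np.random.rand(50, 20, 20)

# Verify the results agree before timing anything
assert np.allclose(original_app(m0, m1), einsum_app(m0, m1))

for name, fn in [("original_app", original_app), ("einsum_app", einsum_app)]:
    # Best of 3 repeats of 10 calls each, reported per call
    best = min(timeit.repeat(lambda: fn(m0, m1), repeat=3, number=10)) / 10
    print(f"{name}: {best * 1e3:.2f} ms per loop")
```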
answered Oct 06 '22 11:10 by Divakar