Dot product for arbitrary shaped NumPy arrays

Given two numpy.ndarray objects, A and B, of arbitrary shapes, I'd like to compute a numpy.ndarray C with the property that C[i] == np.dot(A[i], B[i]) for all i. How can I do this?

Example 1: if A.shape==(2,3,4) and B.shape==(2,4,5), then C.shape should be (2,3,5).

Example 2: if A.shape==(2,3,4) and B.shape==(2,4), then C.shape should be (2,3).
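As a point of reference (not part of the answer below), the property C[i] == np.dot(A[i], B[i]) can be checked with a plain Python loop that stacks the per-item results. A minimal sketch; the helper name dot_per_row is hypothetical, and it assumes A and B have the same length along axis 0:

```python
import numpy as np

def dot_per_row(A, B):
    """Hypothetical reference helper: stack np.dot(A[i], B[i]) along a new first axis."""
    if A.shape[0] != B.shape[0]:
        raise ValueError("A and B must have the same length along axis 0")
    # np.stack also accepts the 0-d results produced when A[i] and B[i] are both 1D
    return np.stack([np.dot(a, b) for a, b in zip(A, B)])

# Shapes from the two examples in the question
print(dot_per_row(np.ones((2, 3, 4)), np.ones((2, 4, 5))).shape)  # (2, 3, 5)
print(dot_per_row(np.ones((2, 3, 4)), np.ones((2, 4))).shape)     # (2, 3)
```

It loops in Python, so it will be slower than a vectorised version when the first axis is long, but it makes a convenient ground truth to test other implementations against.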

Asked by dshin, Nov 28 '25 07:11


1 Answer

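Here's a fairly generic solution using some reshaping and np.einsum; it handles A and B with 2 or 3 dimensions each (all four combinations are shown below), which covers both examples in the question. einsum fits well here because the inputs must be aligned along the first axis (the batch index) and sum-reduced over the last axis of A against axis 1 of B, i.e. the first axis of each B[i]. The implementation would look something like this: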

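import numpy as np

def dotprod_axis0(A, B):
    N, nA, nB = A.shape[0], A.shape[-1], B.shape[1]
    Ar = A.reshape(N, -1, nA)   # (N, rows, nA); a 1D A[i] becomes a single row
    Br = B.reshape(N, nB, -1)   # (N, nB, cols); a 1D B[i] becomes a single column
    C = np.einsum('ijk,ikl->ijl', Ar, Br)   # C[i] = Ar[i] @ Br[i]
    # Reshape rather than np.squeeze, so real length-1 axes (e.g. N == 1) are kept
    return C.reshape(A.shape[:-1] + B.shape[2:])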
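To make the einsum subscripts concrete, here is a sketch that spells out 'ijk,ikl->ijl' with explicit loops: i appears in both inputs and in the output, so it is only aligned, while k appears in both inputs but not in the output, so it is summed. The helper name batched_matmul_loops is hypothetical and only meant for checking small arrays:

```python
import numpy as np

def batched_matmul_loops(Ar, Br):
    """Hypothetical helper: np.einsum('ijk,ikl->ijl', Ar, Br) written as explicit loops."""
    N, J, K = Ar.shape
    N2, K2, L = Br.shape
    if N != N2 or K != K2:
        raise ValueError("batch (i) and contraction (k) lengths must match")
    out = np.zeros((N, J, L), dtype=np.result_type(Ar, Br))
    for i in range(N):              # i: aligned, kept in the output
        for j in range(J):          # j: rows of Ar[i], kept
            for l in range(L):      # l: columns of Br[i], kept
                for k in range(K):  # k: missing from the output, so summed
                    out[i, j, l] += Ar[i, j, k] * Br[i, k, l]
    return out

rng = np.random.default_rng(0)
Ar = rng.integers(0, 9, (2, 3, 4))
Br = rng.integers(0, 9, (2, 4, 5))
print(np.array_equal(batched_matmul_loops(Ar, Br), np.einsum('ijk,ikl->ijl', Ar, Br)))  # True
```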

Cases

I. A : 2D, B : 2D

In [119]: # Inputs
     ...: A = np.random.randint(0,9,(3,4))
     ...: B = np.random.randint(0,9,(3,4))
     ...: 

In [120]: for i in range(A.shape[0]):
     ...:     print(np.dot(A[i], B[i]))
     ...:     
33
86
48

In [121]: dotprod_axis0(A,B)
Out[121]: array([33, 86, 48])
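For the 2D/2D case alone, the reshape is not required: the subscript string 'ij,ij->i' computes the row-wise dot products directly, as does multiplying elementwise and summing over the last axis. A sketch with freshly generated inputs (the seed is arbitrary, so the numbers differ from the session above):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.integers(0, 9, (3, 4))
B = rng.integers(0, 9, (3, 4))

# i is kept (one result per row), j is summed (the dot product within each row)
row_dots = np.einsum('ij,ij->i', A, B)
print(row_dots.shape)  # (3,)

# Same values as elementwise multiply + sum, and as the per-row np.dot loop
assert np.array_equal(row_dots, (A * B).sum(axis=1))
assert np.array_equal(row_dots, [np.dot(A[i], B[i]) for i in range(A.shape[0])])
```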

II. A : 3D, B : 3D

In [122]: # Inputs
     ...: A = np.random.randint(0,9,(2,3,4))
     ...: B = np.random.randint(0,9,(2,4,5))
     ...: 

In [123]: for i in range(A.shape[0]):
     ...:     print(np.dot(A[i], B[i]))
     ...:     
[[ 74  70  53 118  43]
 [ 47  43  29  95  30]
 [ 41  37  26  23  15]]
[[ 50  86  33  35  82]
 [ 78 126  40 124 140]
 [ 67  88  35  47  83]]

In [124]: dotprod_axis0(A,B)
Out[124]: 
array([[[ 74,  70,  53, 118,  43],
        [ 47,  43,  29,  95,  30],
        [ 41,  37,  26,  23,  15]],

       [[ 50,  86,  33,  35,  82],
        [ 78, 126,  40, 124, 140],
        [ 67,  88,  35,  47,  83]]])
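When both inputs are 3D, np.matmul (the @ operator) already treats them as stacks of matrices along the first axis, so it should give the same result as the einsum call here. A sketch checking that equivalence on fresh random inputs:

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.integers(0, 9, (2, 3, 4))
B = rng.integers(0, 9, (2, 4, 5))

# matmul broadcasts over the leading (batch) axis and multiplies the last two axes
C = np.matmul(A, B)  # same as A @ B
print(C.shape)  # (2, 3, 5)

# Same values as the einsum subscript used by dotprod_axis0
assert np.array_equal(C, np.einsum('ijk,ikl->ijl', A, B))
# And the same as the per-item np.dot loop
for i in range(A.shape[0]):
    assert np.array_equal(C[i], np.dot(A[i], B[i]))
```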

III. A : 3D, B : 2D

In [125]: # Inputs
     ...: A = np.random.randint(0,9,(2,3,4))
     ...: B = np.random.randint(0,9,(2,4))
     ...: 

In [126]: for i in range(A.shape[0]):
     ...:     print(np.dot(A[i], B[i]))
     ...:     
[ 87 105  53]
[152 135 120]

In [127]: dotprod_axis0(A,B)
Out[127]: 
array([[ 87, 105,  53],
       [152, 135, 120]])
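For 3D A and 2D B, the matching einsum subscript without any reshaping would be 'ijk,ik->ij'. The sketch below compares it with a matmul that temporarily turns each B[i] into a column; the random inputs are fresh, not the ones above:

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.integers(0, 9, (2, 3, 4))
B = rng.integers(0, 9, (2, 4))

# i: batch (kept), j: rows of A[i] (kept), k: contracted against B[i]
C = np.einsum('ijk,ik->ij', A, B)
print(C.shape)  # (2, 3)

# matmul route: make each B[i] a (4, 1) column, multiply, then drop that axis
C_mm = np.matmul(A, B[:, :, None])[:, :, 0]
assert np.array_equal(C, C_mm)
for i in range(A.shape[0]):
    assert np.array_equal(C[i], np.dot(A[i], B[i]))
```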

IV. A : 2D, B : 3D

In [128]: # Inputs
     ...: A = np.random.randint(0,9,(2,4))
     ...: B = np.random.randint(0,9,(2,4,5))
     ...: 

In [129]: for i in range(A.shape[0]):
     ...:     print(np.dot(A[i], B[i]))
     ...:     
[76 93 31 75 16]
[ 33  98  49 117 111]

In [130]: dotprod_axis0(A,B)
Out[130]: 
array([[ 76,  93,  31,  75,  16],
       [ 33,  98,  49, 117, 111]])
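Symmetrically, for 2D A and 3D B the direct subscript would be 'ij,ijk->ik', which contracts the last axis of A with axis 1 of B. A sketch comparing it against a matmul that turns each A[i] into a row, and against the per-item loop:

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.integers(0, 9, (2, 4))
B = rng.integers(0, 9, (2, 4, 5))

# i: batch (kept), j: contracted (A[i]'s only axis, B[i]'s first axis), k: columns of B[i]
C = np.einsum('ij,ijk->ik', A, B)
print(C.shape)  # (2, 5)

# matmul route: make each A[i] a (1, 4) row, multiply, then drop that axis
C_mm = np.matmul(A[:, None, :], B)[:, 0, :]
assert np.array_equal(C, C_mm)
for i in range(A.shape[0]):
    assert np.array_equal(C[i], np.dot(A[i], B[i]))
```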
Answered by Divakar, Nov 29 '25 21:11


