 

Matrix-multiply every pair of 2-D arrays along the first dimension with einsum

Tags:

python

numpy

I have two 3-D arrays, a and b, of the same shape:

import numpy as np

np.random.seed([3, 14159])
a = np.random.randint(10, size=(4, 3, 2))
b = np.random.randint(10, size=(4, 3, 2))

print(a)

[[[4 8]
  [1 1]
  [9 2]]

 [[8 1]
  [4 2]
  [8 2]]

 [[8 4]
  [9 4]
  [3 4]]

 [[1 5]
  [1 2]
  [6 2]]]

print(b)

[[[7 7]
  [1 1]
  [7 8]]

 [[7 4]
  [8 0]
  [0 9]]

 [[3 8]
  [7 7]
  [2 6]]

 [[3 1]
  [9 3]
  [0 5]]]

I want to take the first array from a

a[0]

[[4 8]
 [1 1]
 [9 2]]

And the first one from b

b[0]

[[7 7]
 [1 1]
 [7 8]]

And compute this:

a[0].T.dot(b[0])

[[ 92 101]
 [ 71  73]]

But I want to do this across the entire first dimension. I thought I could use np.einsum:

np.einsum('abc,ade->ace', a, b)

[[[210 224]
  [165 176]]

 [[300 260]
  [ 75  65]]

 [[240 420]
  [144 252]]

 [[ 96  72]
  [108  81]]]

This has the correct shape, but the wrong values.

I expect to get this:

np.array([x.T.dot(y).tolist() for x, y in zip(a, b)])

[[[ 92 101]
  [ 71  73]]

 [[ 88 104]
  [ 23  22]]

 [[ 93 145]
  [ 48  84]]

 [[ 12  34]
  [ 33  21]]]
piRSquared asked Oct 11 '25 07:10



1 Answer

Matrix multiplication is a sum of products taken over the middle axis, so that axis must get the same subscript, b, in both arrays (i.e. change ade to abe):

In [40]: np.einsum('abc,abe->ace', a, b)
Out[40]: 
array([[[ 92, 101],
        [ 71,  73]],

       [[ 88, 104],
        [ 23,  22]],

       [[ 93, 145],
        [ 48,  84]],

       [[ 12,  34],
        [ 33,  21]]])
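To make the subscript rule concrete, here is a minimal sketch that spells out 'abc,abe->ace' as explicit loops: b appears in both inputs but not in the output, so it is the summed index, while a is carried through unchanged as a batch index. It also checks the result against np.matmul (the @ operator), which treats the leading axis as a batch axis; that is an alternative to einsum, not part of the original answer. The helper name einsum_abc_abe_ace is hypothetical, and the arrays are copied from the printed a and b in the question.

```python
import numpy as np

# a and b copied from the printed output in the question
a = np.array([[[4, 8], [1, 1], [9, 2]],
              [[8, 1], [4, 2], [8, 2]],
              [[8, 4], [9, 4], [3, 4]],
              [[1, 5], [1, 2], [6, 2]]])
b = np.array([[[7, 7], [1, 1], [7, 8]],
              [[7, 4], [8, 0], [0, 9]],
              [[3, 8], [7, 7], [2, 6]],
              [[3, 1], [9, 3], [0, 5]]])


def einsum_abc_abe_ace(x, y):
    """Hypothetical helper: explicit-loop version of np.einsum('abc,abe->ace', x, y)."""
    A, B, C = x.shape
    E = y.shape[2]
    out = np.zeros((A, C, E), dtype=np.result_type(x, y))
    for i in range(A):              # 'a': kept in the output (batch index)
        for k in range(C):          # 'c': kept in the output (row of each result)
            for m in range(E):      # 'e': kept in the output (column of each result)
                for j in range(B):  # 'b': shared by both inputs, absent from output -> summed
                    out[i, k, m] += x[i, j, k] * y[i, j, m]
    return out


loops = einsum_abc_abe_ace(a, b)
fixed = np.einsum('abc,abe->ace', a, b)
# matmul broadcasts over axis 0; swapaxes(1, 2) transposes each 3x2 slice to 2x3
batched = np.swapaxes(a, 1, 2) @ b

print(np.array_equal(loops, fixed), np.array_equal(fixed, batched))  # True True
```

The innermost loop over j is exactly the dot product in a[i].T.dot(b[i]); giving both inputs the same letter is what pairs row j of a[i] with row j of b[i].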

Subscripts that appear in the inputs but not in the output are summed over, and each one is summed independently. That is,

np.einsum('abc,ade->ace', a, b)

is equivalent to

In [44]: np.einsum('abc,ade->acebd', a, b).sum(axis=-1).sum(axis=-1)
Out[44]: 
array([[[210, 224],
        [165, 176]],

       [[300, 260],
        [ 75,  65]],

       [[240, 420],
        [144, 252]],

       [[ 96,  72],
        [108,  81]]])
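Because b and d are summed independently, the original subscripts reduce to a per-slice outer product of column sums: each a[i] is summed over its middle axis, each b[i] likewise, and the two resulting length-2 vectors are combined as an outer product. A minimal sketch checking this against the question's arrays (copied from its printed output):

```python
import numpy as np

# a and b copied from the printed output in the question
a = np.array([[[4, 8], [1, 1], [9, 2]],
              [[8, 1], [4, 2], [8, 2]],
              [[8, 4], [9, 4], [3, 4]],
              [[1, 5], [1, 2], [6, 2]]])
b = np.array([[[7, 7], [1, 1], [7, 8]],
              [[7, 4], [8, 0], [0, 9]],
              [[3, 8], [7, 7], [2, 6]],
              [[3, 1], [9, 3], [0, 5]]])

wrong = np.einsum('abc,ade->ace', a, b)

# Sum each input over its own middle axis first ('b' for a, 'd' for b)...
a_sums = a.sum(axis=1)  # shape (4, 2)
b_sums = b.sum(axis=1)  # shape (4, 2)
# ...then take a per-slice outer product: out[i, c, e] = a_sums[i, c] * b_sums[i, e]
outer = a_sums[:, :, None] * b_sums[:, None, :]

print(np.array_equal(wrong, outer))  # True
print(a_sums[0], b_sums[0])          # [14 11] [15 16]
```

For the first slice this gives 14 * 15 = 210 and 14 * 16 = 224, matching the first row of the incorrect result; no row of a[i] is ever paired with its corresponding row of b[i].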
unutbu answered Oct 16 '25 12:10



