I need to perform matrix multiplication on two 4D arrays (m and n) with shapes 2x2x2x2 and 2x3x2x2 respectively, which should result in a 2x3x2x2 array. After a lot of research (mostly on this site) it seems this can be done efficiently with either np.einsum or np.tensordot, but I am unable to replicate the answer I am getting from Matlab (verified by hand). I understand how these methods work when performing matrix multiplication on 2D arrays (clearly explained here), but I cannot get the axes indices correct for the 4D arrays. Clearly I’m missing something!
My actual problem deals with two 23x23x3x3 arrays of complex numbers, but my test arrays are:
a = np.array([[1, 7], [4, 3]]) 
b = np.array([[2, 9], [4, 5]]) 
c = np.array([[3, 6], [1, 0]]) 
d = np.array([[2, 8], [1, 2]]) 
e = np.array([[0, 0], [1, 2]])
f = np.array([[2, 8], [1, 0]])
m = np.array([[a, b], [c, d]])              # (2,2,2,2)
n = np.array([[e, f, a], [b, d, c]])        # (2,3,2,2)
I realise the complex numbers may present further issues, but for now I am just trying to understand how the indexing works with einsum and tensordot. The answer I’m chasing is this 2x3x2x2 array:
+----+-----------+-----------+-----------+
|    | 0         | 1         | 2         |
+====+===========+===========+===========+
|  0 | [[47 77]  | [[22 42]  | [[44 40]  |
|    |  [31 67]] |  [24 74]] |  [33 61]] |
+----+-----------+-----------+-----------+
|  1 | [[42 70]  | [[24 56]  | [[41 51]  |
|    |  [10 19]] |  [ 6 20]] |  [ 6 13]] |
+----+-----------+-----------+-----------+
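For reference, this target is just a block matrix product: each 2x2 block of the result is a sum of 2x2 matrix products, out[i, p] = sum over j of m[i, j] @ n[j, p]. A plain-loop sketch (slow, but unambiguous) that reproduces it from the test arrays above:

```python
import numpy as np

a = np.array([[1, 7], [4, 3]])
b = np.array([[2, 9], [4, 5]])
c = np.array([[3, 6], [1, 0]])
d = np.array([[2, 8], [1, 2]])
e = np.array([[0, 0], [1, 2]])
f = np.array([[2, 8], [1, 0]])
m = np.array([[a, b], [c, d]])        # (2,2,2,2)
n = np.array([[e, f, a], [b, d, c]])  # (2,3,2,2)

# out[i, p] is the 2x2 block given by sum over j of m[i, j] @ n[j, p]
out = np.zeros((2, 3, 2, 2), dtype=int)
for i in range(2):
    for p in range(3):
        for j in range(2):
            out[i, p] += m[i, j] @ n[j, p]
```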
My closest attempt uses np.tensordot:
mn = np.tensordot(m,n, axes=([1,3],[0,2]))
which gives me a 2x2x3x2 array with correct numbers but not in the right order:
+----+-----------+-----------+
|    | 0         | 1         |
+====+===========+===========+
|  0 | [[47 77]  | [[31 67]  |
|    |  [22 42]  |  [24 74]  |
|    |  [44 40]] |  [33 61]] |
+----+-----------+-----------+
|  1 | [[42 70]  | [[10 19]  |
|    |  [24 56]  |  [ 6 20]  |
|    |  [41 51]] |  [ 6 13]] |
+----+-----------+-----------+
I’ve also tried to implement some of the solutions from here but have not had any luck.
Any ideas on how I might improve this would be greatly appreciated, thanks
You could simply swap the axes on the tensordot result, so that we still leverage BLAS-based sum-reductions with tensordot -
np.tensordot(m,n, axes=((1,3),(0,2))).swapaxes(1,2)
Alternatively, we could swap the positions of m and n in the tensordot call and transpose to re-arrange all the axes -
np.tensordot(n,m, axes=((0,2),(1,3))).transpose(2,0,3,1)
Using the manual labor of reshaping and swapping axes, we can bring in 2D matrix multiplication with np.dot as well, like so -
m0,m1,m2,m3 = m.shape
n0,n1,n2,n3 = n.shape
m2D = m.swapaxes(1,2).reshape(-1,m1*m3)
n2D = n.swapaxes(1,2).reshape(n0*n2,-1)
out = m2D.dot(n2D).reshape(m0,m2,n1,n3).swapaxes(1,2)
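As a sanity check (a quick sketch on small random inputs, not part of the original post), all three variants agree with the einsum reference:

```python
import numpy as np

m = np.random.rand(2, 2, 2, 2)
n = np.random.rand(2, 3, 2, 2)

# Reference: contract j and l, keeping i, m, k, n
ref = np.einsum('ijkl,jmln->imkn', m, n)

# Variant 1: tensordot + swapaxes
v1 = np.tensordot(m, n, axes=((1, 3), (0, 2))).swapaxes(1, 2)

# Variant 2: swapped-operand tensordot + transpose
v2 = np.tensordot(n, m, axes=((0, 2), (1, 3))).transpose(2, 0, 3, 1)

# Variant 3: manual reshape + 2D np.dot
m0, m1, m2, m3 = m.shape
n0, n1, n2, n3 = n.shape
m2D = m.swapaxes(1, 2).reshape(-1, m1 * m3)
n2D = n.swapaxes(1, 2).reshape(n0 * n2, -1)
v3 = m2D.dot(n2D).reshape(m0, m2, n1, n3).swapaxes(1, 2)

assert np.allclose(ref, v1) and np.allclose(ref, v2) and np.allclose(ref, v3)
```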
Runtime test -
Scaling the input arrays to 10x shapes:
In [85]: m = np.random.rand(20,20,20,20)
In [86]: n = np.random.rand(20,30,20,20)
# @Daniel F's soln with einsum
In [87]: %timeit np.einsum('ijkl,jmln->imkn', m, n)
10 loops, best of 3: 136 ms per loop
In [126]: %timeit np.tensordot(m,n, axes=((1,3),(0,2))).swapaxes(1,2)
100 loops, best of 3: 2.31 ms per loop
In [127]: %timeit np.tensordot(n,m, axes=((0,2),(1,3))).transpose(2,0,3,1)
100 loops, best of 3: 2.37 ms per loop
In [128]: %%timeit
     ...: m0,m1,m2,m3 = m.shape
     ...: n0,n1,n2,n3 = n.shape
     ...: m2D = m.swapaxes(1,2).reshape(-1,m1*m3)
     ...: n2D = n.swapaxes(1,2).reshape(n0*n2,-1)
     ...: out = m2D.dot(n2D).reshape(m0,m2,n1,n3).swapaxes(1,2)
100 loops, best of 3: 2.36 ms per loop
Your best bet, since your reduction dimensions neither match (which would allow broadcasting) nor are the "inner" dimensions (which would work natively with np.tensordot), is to use np.einsum:
np.einsum('ijkl,jmln->imkn', m, n)
array([[[[47, 77],
         [31, 67]],
        [[22, 42],
         [24, 74]],
        [[44, 40],
         [33, 61]]],
       [[[42, 70],
         [10, 19]],
        [[24, 56],
         [ 6, 20]],
        [[41, 51],
         [ 6, 13]]]])
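To spell out what the subscripts mean (a hedged restatement, with the m subscript renamed p in the loop so it doesn't clash with the array named m): the repeated labels j and l are summed away, while i, m, k, n survive into the output:

```python
import numpy as np

m = np.random.rand(2, 2, 2, 2)
n = np.random.rand(2, 3, 2, 2)

ref = np.einsum('ijkl,jmln->imkn', m, n)

# Equivalent explicit loops:
# loop[i, p, k, q] = sum over j, l of m[i, j, k, l] * n[j, p, l, q]
loop = np.empty((2, 3, 2, 2))
for i in range(2):
    for p in range(3):
        for k in range(2):
            for q in range(2):
                loop[i, p, k, q] = sum(
                    m[i, j, k, l] * n[j, p, l, q]
                    for j in range(2)
                    for l in range(2)
                )
assert np.allclose(ref, loop)
```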
Just to demonstrate that broadcasting also works:
(m[:, :, None, :, :, None] * n[None, :, :, None, :, :]).sum(axis=(1,4))
But the other solutions posted are probably faster, at least for large arrays.
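A quick equivalence check (a sketch, not from the original answer) that the broadcasting expression matches the einsum result:

```python
import numpy as np

m = np.random.rand(2, 2, 2, 2)
n = np.random.rand(2, 3, 2, 2)

# Insert singleton axes so the contracted j and l axes line up,
# multiply elementwise, then sum them out (j at position 1, l at 4).
bc = (m[:, :, None, :, :, None] * n[None, :, :, None, :, :]).sum(axis=(1, 4))
ref = np.einsum('ijkl,jmln->imkn', m, n)
assert np.allclose(bc, ref)
```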