 

numpy multidimensional (3d) matrix multiplication

I have two 3D matrices, A (32x3x3) and B (32x3x3), and I want to compute a matrix C, also 32x3x3, where each C[i] is the matrix product of A[i] and B[i]. The calculation can be done using a loop like:

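import numpy

a = numpy.random.rand(32, 3, 3)
b = numpy.random.rand(32, 3, 3)
c = numpy.empty((32, 3, 3))  # output buffer; every slice is overwritten below

# Multiply each matching pair of 3x3 slices: c[i] = a[i] . b[i]
for i in range(32):
    c[i] = numpy.dot(a[i], b[i])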

I believe there must be a more efficient, one-line solution to this problem. Can anybody help? Thanks.

asked Oct 02 '16 22:10 by user1459581



People also ask

How do you multiply multidimensional matrices?

When multiplying multidimensional matrices, the number of elements along the dimension being multiplied in the first matrix must equal the number of elements along the dimension being multiplied in the second matrix. For ordinary 2D matrices this is the familiar rule that the number of columns of A (its second dimension) must equal the number of rows of B (its first dimension). That is, Ndb(A) = Nda(B).
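In NumPy terms, a minimal sketch of this rule for stacks of matrices, assuming np.matmul is used for the product: the last axis of the first array must match the second-to-last axis of the second array, and a mismatch raises a ValueError. The shapes and variable names below are illustrative only.

```python
import numpy as np

# Stacks of matrices: 32 matrices of shape 3x4 and 32 of shape 4x2.
a = np.random.rand(32, 3, 4)
b = np.random.rand(32, 4, 2)

# Inner dimensions match (4 == 4), so each per-slice product is 3x2.
c = np.matmul(a, b)
print(c.shape)  # (32, 3, 2)

# Inner dimensions do not match (4 != 3), so NumPy refuses.
bad = np.random.rand(32, 3, 2)
mismatch_raised = False
try:
    np.matmul(a, bad)
except ValueError as err:
    mismatch_raised = True
    print("shape mismatch:", err)
```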

Can you multiply 3D matrices?

A 3D matrix is nothing but a collection (or a stack) of many 2D matrices, just like how a 2D matrix is a collection/stack of many 1D vectors. So, matrix multiplication of 3D matrices involves multiple multiplications of 2D matrices, which eventually boils down to a dot product between their row/column vectors.
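A short sketch of the "stack of 2D matrices" view, assuming NumPy's @ operator (np.matmul), which treats every axis before the last two as a stack index. It checks the result against multiplying each pair of 2D slices explicitly.

```python
import numpy as np

a = np.random.rand(32, 3, 3)
b = np.random.rand(32, 3, 3)

# @ multiplies the matching 3x3 matrices in the stack: c[i] = a[i] @ b[i].
c = a @ b

# Same result as multiplying each pair of 2D slices one at a time.
expected = np.stack([a[i] @ b[i] for i in range(len(a))])
print(np.allclose(c, expected))  # True
```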

How do you perform matrix multiplication on the NumPy arrays?

To multiply two 2D matrices, use NumPy's dot() function: np.dot(a, b) returns their matrix product (an optional third argument, out, can receive the result).
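A small illustration of np.dot on 2D arrays, plus a caveat relevant to the question: for 3D inputs, np.dot does not pair a[i] with b[i]. It sums over the last axis of a and the second-to-last axis of b, so every a[i] is multiplied with every b[k] and the result has four axes.

```python
import numpy as np

m = np.array([[1, 2], [3, 4]])
n = np.array([[5, 6], [7, 8]])

# For 2D arrays, np.dot is ordinary matrix multiplication.
print(np.dot(m, n))
# [[19 22]
#  [43 50]]

# For 3D arrays, np.dot(a, b)[i, :, k, :] is a[i] @ b[k] for every (i, k),
# so the batched product the question wants sits on the i == k "diagonal".
a = np.random.rand(32, 3, 3)
b = np.random.rand(32, 3, 3)
d = np.dot(a, b)
print(d.shape)  # (32, 3, 32, 3)
```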


1 Answer

You could do this using np.einsum (here orig(a, b) is presumably the question's loop wrapped in a function):

In [142]: old = orig(a,b)

In [143]: new = np.einsum('ijk,ikl->ijl', a, b)

In [144]: np.allclose(old, new)
Out[144]: True

One advantage of using einsum is that you can almost read off what it's doing from the indices: leave the first axis alone (i) and perform a matrix multiplication on the last two (jk,kl->jl).
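A self-contained version of the comparison above, sketched under the assumption that orig(a, b) is the question's loop wrapped in a function; here that wrapper is given the hypothetical name loop_product. It also checks np.matmul (the @ operator), which performs the same batched product for arrays shaped (n, j, k) and (n, k, l).

```python
import numpy as np

def loop_product(a, b):
    """Hypothetical wrapper: multiply matching 2D slices one at a time."""
    c = np.empty((a.shape[0], a.shape[1], b.shape[2]))
    for i in range(a.shape[0]):
        c[i] = np.dot(a[i], b[i])
    return c

a = np.random.rand(32, 3, 3)
b = np.random.rand(32, 3, 3)

old = loop_product(a, b)

# 'ijk,ikl->ijl': keep the stack axis i, sum over the shared axis k,
# giving a j-by-l matrix for each i.
new = np.einsum('ijk,ikl->ijl', a, b)

# np.matmul (the @ operator) broadcasts over the leading axis the same way.
also_new = a @ b

print(np.allclose(old, new), np.allclose(old, also_new))  # True True
```

Both one-liners avoid the Python-level loop; einsum makes the index bookkeeping explicit, while @ is shorter when the operation is a plain batched matrix product.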

answered Oct 18 '22 04:10 by DSM
