 

Python, numpy, einsum multiply a stack of matrices

For performance reasons, I'm curious whether there is a way to multiply a stack of stacks of matrices. I have a 4-D array of shape (500, 201, 2, 2). It's basically a 500-long stack of (201, 2, 2) arrays, and along that 500-axis I want to multiply adjacent matrices with einsum, reducing everything to a single (201, 2, 2) result.

The actual matrix multiplication only happens on the 2x2 matrices at the end. Since my explanation is already heading off the rails, I'll just show what I'm doing now, along with the 'reduce' equivalent and why it's not helpful (it's computationally the same speed). Ideally this would be a numpy one-liner, but I don't know what that is, or even whether it's possible.

Code:

import numpy as np
from functools import reduce  # a builtin in Python 2

Arr = np.random.rand(500, 201, 2, 2)

def loopMult(Arr):
    ArrMult = Arr[0]
    for i in range(1, len(Arr)):
        ArrMult = np.einsum('fij,fjk->fik', ArrMult, Arr[i])
    return ArrMult

def myeinsum(A1, A2):
    return np.einsum('fij,fjk->fik', A1, A2)

A1 = loopMult(Arr)
A2 = reduce(myeinsum, Arr)
print(np.all(A1 == A2))

print(A1.shape); print(A2.shape)

%timeit loopMult(Arr)
%timeit reduce(myeinsum, Arr)

Returns:

True
(201, 2, 2)
(201, 2, 2)
10 loops, best of 3: 34.8 ms per loop
10 loops, best of 3: 35.2 ms per loop

Any help would be appreciated. Things are functional, but when I have to iterate this over a large series of parameters, the code tends to take a long time, and I'm wondering if there's a way to avoid the 500 iterations through a loop.

asked Sep 22 '14 by Vince W.

1 Answer

I don't think it's possible to do this efficiently in pure numpy (though the cumprod solution was elegant). This is the sort of situation where I would use f2py. It's the simplest way I know of to call a faster language, and it only requires a single extra file.

fortran.f90:

subroutine multimul(a, b)
  ! Reduce the last axis of a by repeated matrix multiplication.
  implicit none
  real(8), intent(in)  :: a(:,:,:,:)
  real(8), intent(out) :: b(size(a,1),size(a,2),size(a,3))
  real(8) :: work(size(a,1),size(a,2))
  integer i, j
  !$omp parallel do private(work,i,j)
  do i = 1, size(b,3)
    ! Start from the last matrix and fold the rest in, in reverse,
    ! because the arrays arrive transposed from the Python side.
    b(:,:,i) = a(:,:,i,size(a,4))
    do j = size(a,4)-1, 1, -1
      work = matmul(b(:,:,i), a(:,:,i,j))
      b(:,:,i) = work
    end do
  end do
end subroutine

Compile with

f2py -c -m fortran fortran.f90

or, to enable OpenMP acceleration,

F90FLAGS="-fopenmp" f2py -c -m fortran fortran.f90 -lgomp

Then you would use it in your script as

import numpy as np
from functools import reduce  # a builtin in Python 2
import fortran  # the f2py module compiled above

Arr = np.random.standard_normal([500, 201, 2, 2])

def loopMult(Arr):
    ArrMult = Arr[0]
    for i in range(1, len(Arr)):
        ArrMult = np.einsum('fij,fjk->fik', ArrMult, Arr[i])
    return ArrMult

def myeinsum(A1, A2):
    return np.einsum('fij,fjk->fik', A1, A2)

A1 = loopMult(Arr)
A2 = reduce(myeinsum, Arr)
A3 = fortran.multimul(Arr.T).T
print(np.allclose(A1, A2))
print(np.allclose(A1, A3))
%timeit loopMult(Arr)
%timeit reduce(myeinsum, Arr)
%timeit fortran.multimul(Arr.T).T

Which outputs

True
True
10 loops, best of 3: 48.4 ms per loop
10 loops, best of 3: 48.8 ms per loop
100 loops, best of 3: 5.82 ms per loop

So that's a factor 8 speedup. The reason for all the transposes is that f2py implicitly transposes all the arrays; by transposing them manually on the Python side, we tell it that our fortran code expects transposed data, which avoids a copy operation. The cost is that each of our 2x2 matrices is itself transposed, so to avoid computing the wrong product we have to loop in reverse.
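The reversed loop follows from the identity (AB)ᵀ = BᵀAᵀ, and the zero-copy claim can be checked directly. A small illustrative check (mine, not part of the original answer):

```python
import numpy as np

A = np.arange(4.0).reshape(2, 2)
B = np.arange(4.0, 8.0).reshape(2, 2)

# Transposition reverses the order of a matrix product,
# which is why the Fortran loop has to run backwards.
assert np.allclose((A @ B).T, B.T @ A.T)

# .T on a C-ordered array is a zero-copy Fortran-ordered view,
# so passing Arr.T into f2py avoids a copy.
Arr = np.zeros((500, 201, 2, 2))
assert Arr.T.flags['F_CONTIGUOUS']
assert np.shares_memory(Arr, Arr.T)
```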

Speedups greater than 8x should be possible; I didn't spend any time trying to optimize this.
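One can also stay in pure numpy and still avoid the 500 sequential steps. Since matrix multiplication is associative, adjacent pairs of stacks can be multiplied in one vectorized einsum call, collapsing 500 steps into roughly log2(500) ≈ 9 passes. A sketch of this pairwise (tree) reduction — my own illustration, not from the original answer, and the helper name `pairwise_mult` is hypothetical:

```python
import numpy as np

def pairwise_mult(Arr):
    """Tree reduction over the first axis of a (N, f, 2, 2) stack.

    Each pass multiplies adjacent stacks in one batched einsum,
    halving N, so only ~log2(N) passes are needed instead of N-1
    sequential steps. Associativity keeps the product order intact.
    """
    while len(Arr) > 1:
        if len(Arr) % 2:  # odd length: set the last stack aside
            extra, Arr = Arr[-1:], Arr[:-1]
        else:
            extra = None
        # Multiply stack 0 with 1, 2 with 3, ... in a single call.
        Arr = np.einsum('sfij,sfjk->sfik', Arr[::2], Arr[1::2])
        if extra is not None:
            Arr = np.concatenate([Arr, extra])
    return Arr[0]
```

Whether this beats the loop in practice depends on the shapes: each pass touches the whole array, but the per-call Python overhead drops from 500 einsum calls to about 9.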

answered Oct 15 '22 by amaurea