Efficient Kronecker product with identity matrix and regular matrix - NumPy/Python

I am working on a Python project that uses NumPy. I frequently have to compute Kronecker products of matrices with the identity matrix. These are a pretty big bottleneck in my code, so I would like to optimize them. There are two kinds of products I have to take. The first one is:

np.kron(np.eye(N), A)

This one is pretty easy to optimize by simply using scipy.linalg.block_diag. With scipy.linalg imported as la, the product is equivalent to:

la.block_diag(*[A]*N)

This is about 10 times faster. However, I am unsure how to optimize the second kind of product:

np.kron(A, np.eye(N))

Is there a similar trick I can use?

asked Jun 09 '17 15:06 by user3930598



1 Answer

One approach would be to initialize a 4D output array of zeros and then assign values into it from A. The assignment broadcasts A across the matching positions of the second and fourth axes, and this is where we get the efficiency in NumPy.

Thus, a solution would be like so -

# Get shape of A
m,n = A.shape

# Initialize output array as 4D
out = np.zeros((m,N,n,N))

# Get range array for indexing into the second and fourth axes 
r = np.arange(N)

# Index into the second and fourth axes with the same range array, selecting
# all elements along the other two axes, and assign A. The values are broadcast.
out[:,r,:,r] = A

# Finally reshape back to 2D
out.shape = (m*N,n*N)
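
To see why the final reshape produces the Kronecker layout, it may help to trace a small case. The sketch below uses a hypothetical 2x3 integer matrix and N = 2 (both chosen only for illustration) and checks that element (i, a, j, b) of the 4D array ends up at row i*N + a, column j*N + b of the 2D result. It uses reshape instead of assigning to out.shape, which gives the same result here.

```python
import numpy as np

# Hypothetical inputs, chosen only to keep the output readable.
A = np.array([[1, 2, 3],
              [4, 5, 6]])
N = 2
m, n = A.shape

out = np.zeros((m, N, n, N), dtype=A.dtype)
r = np.arange(N)
out[:, r, :, r] = A  # sets out[i, a, j, a] = A[i, j] for every a in range(N)

out2d = out.reshape(m * N, n * N)

# The 4D element (i, a, j, b) lands at row i*N + a, column j*N + b.
# It equals A[i, j] when a == b and stays zero otherwise, which is
# exactly the block A[i, j] * eye(N) of the Kronecker product.
for i in range(m):
    for a in range(N):
        for j in range(n):
            for b in range(N):
                expected = A[i, j] if a == b else 0
                assert out2d[i * N + a, j * N + b] == expected

print(out2d)
# [[1 0 2 0 3 0]
#  [0 1 0 2 0 3]
#  [4 0 5 0 6 0]
#  [0 4 0 5 0 6]]
```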

Put as a function -

def kron_A_N(A, N):  # Simulates np.kron(A, np.eye(N))
    m,n = A.shape
    out = np.zeros((m,N,n,N),dtype=A.dtype)
    r = np.arange(N)
    out[:,r,:,r] = A
    out.shape = (m*N,n*N)
    return out

To simulate np.kron(np.eye(N), A), swap the roles of the first and second axes, and likewise of the third and fourth axes -

def kron_N_A(A, N):  # Simulates np.kron(np.eye(N), A)
    m,n = A.shape
    out = np.zeros((N,m,N,n),dtype=A.dtype)
    r = np.arange(N)
    out[r,:,r,:] = A
    out.shape = (m*N,n*N)
    return out
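
Both helpers can be sanity-checked against np.kron on a small non-square input. The snippet below is a minimal sketch: it repeats the two functions so that it runs on its own (using reshape in place of the in-place shape assignment), and the test matrix and N are assumptions picked for illustration. It also shows one difference that may matter: the helpers keep the dtype of A, whereas np.kron with np.eye(N) promotes an integer A to float.

```python
import numpy as np

def kron_A_N(A, N):  # Simulates np.kron(A, np.eye(N))
    m, n = A.shape
    out = np.zeros((m, N, n, N), dtype=A.dtype)
    r = np.arange(N)
    out[:, r, :, r] = A
    return out.reshape(m * N, n * N)

def kron_N_A(A, N):  # Simulates np.kron(np.eye(N), A)
    m, n = A.shape
    out = np.zeros((N, m, N, n), dtype=A.dtype)
    r = np.arange(N)
    out[r, :, r, :] = A
    return out.reshape(m * N, n * N)

# Hypothetical integer, non-square test input.
A = np.arange(6).reshape(2, 3)
N = 4

left = kron_A_N(A, N)
right = kron_N_A(A, N)

# The values agree with np.kron for both orderings.
assert np.array_equal(left, np.kron(A, np.eye(N)))
assert np.array_equal(right, np.kron(np.eye(N), A))

# The helpers preserve A's integer dtype; np.kron with np.eye(N) returns floats.
print(left.dtype, np.kron(A, np.eye(N)).dtype)
```

If a float result is required for integer input, passing A.astype(float) to the helpers is one way to match np.kron's output dtype.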

Timings -

In [174]: N = 100
     ...: A = np.random.rand(100,100)
     ...: 

In [175]: np.allclose(np.kron(A, np.eye(N)), kron_A_N(A,N))
Out[175]: True

In [176]: %timeit np.kron(A, np.eye(N))
1 loops, best of 3: 458 ms per loop

In [177]: %timeit kron_A_N(A, N)
10 loops, best of 3: 58.4 ms per loop

In [178]: 458/58.4
Out[178]: 7.842465753424658
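
The session above relies on IPython's %timeit. Outside IPython, a comparison along the same lines could be sketched with the standard timeit module. The helper name best_time and the smaller problem size are assumptions made here to keep memory use modest, and the numbers it prints will depend on the machine and the NumPy version.

```python
import timeit
import numpy as np

def kron_A_N(A, N):  # Simulates np.kron(A, np.eye(N))
    m, n = A.shape
    out = np.zeros((m, N, n, N), dtype=A.dtype)
    r = np.arange(N)
    out[:, r, :, r] = A
    return out.reshape(m * N, n * N)

def best_time(func, repeat=3, number=5):
    """Hypothetical helper: best average time per call, in seconds."""
    return min(timeit.repeat(func, repeat=repeat, number=number)) / number

# Smaller than the 100x100, N = 100 case above, so the result is only 1000x1000.
N = 20
A = np.random.rand(50, 50)

# Confirm both approaches agree before timing them.
assert np.allclose(np.kron(A, np.eye(N)), kron_A_N(A, N))

t_kron = best_time(lambda: np.kron(A, np.eye(N)))
t_fast = best_time(lambda: kron_A_N(A, N))

print(f"np.kron : {t_kron * 1e3:.3f} ms per call")
print(f"kron_A_N: {t_fast * 1e3:.3f} ms per call")
print(f"ratio   : {t_kron / t_fast:.2f}")
```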
answered Sep 30 '22 01:09 by Divakar
