 

Compute commutation matrix in numpy/scipy efficiently

I am trying to compute the commutation matrix in Python for a large dataset. I wrote the following code, but it performs terribly (and runs into memory errors for examples of around 500 by 500). In my code, a and b are equivalent to the m and n notation on the linked Wikipedia page for the commutation matrix. Can anyone provide a quicker and more memory-efficient alternative to my current attempt?
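For reference, the defining property from that page can be restated with 0-based indices (to match NumPy), where vec stacks the columns of a matrix and a, b in the code play the roles of m, n:

```latex
K^{(m,n)}\,\operatorname{vec}(A) = \operatorname{vec}(A^{\mathsf{T}}),
\qquad A \in \mathbb{R}^{m \times n},\quad K^{(m,n)} \in \{0,1\}^{mn \times mn}

\operatorname{vec}(A)_{i + jm} = A_{ij},
\qquad
\operatorname{vec}(A^{\mathsf{T}})_{j + in} = A_{ij}

K^{(m,n)}_{\,j + in,\; i + jm} = 1
\quad \text{for } 0 \le i < m,\; 0 \le j < n, \text{ and all other entries are } 0
```

Each row and each column of K contains exactly one 1, so K is a permutation matrix with only mn nonzeros.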

import numpy as np
import scipy.sparse


def vec(matrix):
    # Return the vectorised matrix (columns stacked) as a (rows*cols, 1) column vector
    return matrix.transpose().reshape(matrix.shape[0] * matrix.shape[1], 1)


def commutation(a, b):
    # Example matrix with unique elements
    m = np.arange(a * b).reshape(a, b)

    # vec(m), flattened to 1-D
    vecm = vec(m)
    vecm = vecm.reshape(vecm.shape[0])

    # Row indices
    rowInds = np.arange(a * b)

    # Column indices: argsort of vec(m) gives the column of the 1 in each row
    colInds = np.argsort(vecm)

    # Work out the mapping between them
    K = scipy.sparse.csr_matrix((np.ones(a * b), (rowInds, colInds)))

    return K
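To check that the argsort trick really produces the commutation matrix, here is a small sketch that compares it against a direct double loop placing the 1 for each element (i, j). `commutation_loop` is a hypothetical reference helper written only for this comparison, and the question's functions are restated (with imports) so the snippet runs on its own.

```python
import numpy as np
import scipy.sparse


def vec(matrix):
    # Stack the columns of `matrix` into a single (m*n, 1) column vector
    return matrix.transpose().reshape(matrix.shape[0] * matrix.shape[1], 1)


def commutation(a, b):
    # Question's approach: vec() of a matrix of unique labels, then argsort
    m = np.arange(a * b).reshape(a, b)
    col_inds = np.argsort(vec(m).ravel())
    row_inds = np.arange(a * b)
    return scipy.sparse.csr_matrix(
        (np.ones(a * b), (row_inds, col_inds)), shape=(a * b, a * b)
    )


def commutation_loop(a, b):
    # Hypothetical reference helper: element A[i, j] sits at index i + j*a
    # in vec(A) and at index j + i*b in vec(A.T)
    K = np.zeros((a * b, a * b))
    for i in range(a):
        for j in range(b):
            K[j + i * b, i + j * a] = 1.0
    return K


a, b = 2, 3
A = np.arange(1.0, a * b + 1).reshape(a, b)
K = commutation(a, b)
print(np.array_equal(K.toarray(), commutation_loop(a, b)))  # True
print(np.array_equal(K @ vec(A), vec(A.T)))                 # True
```

The loop is O((ab)²) in memory because it is dense, so it is only meant as a correctness check on small sizes.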
Asked Dec 08 '25 08:12 by JDoe2



1 Answer

Below is an improved version of your code:

import numpy as np
from scipy.sparse import csr_matrix

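def vec(A):
    # Stack the columns of A into a 1-D vector (column-major order)
    m, n = A.shape
    return A.reshape(m*n, order='F')

def commutation_matrix_sp(A):
    m, n = A.shape
    # Row r of K corresponds to position r of vec(A.T)
    row  = np.arange(m*n)
    # Column-major position of each element, read out in row-major order
    col  = row.reshape((m, n), order='F').ravel()
    # One nonzero per row; int8 keeps the data array small
    data = np.ones(m*n, dtype=np.int8)
    K = csr_matrix((data, (row, col)), shape=(m*n, m*n))
    return K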
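To see why `row.reshape((m, n), order='F').ravel()` gives the right column indices, consider a small 2×3 case. Reshaping `arange(m*n)` in Fortran order writes each element's column-major position (its index in vec(A)) into its (i, j) slot, and `ravel()` reads those slots back in row-major order, which is the order of vec(A.T). Because every row of K holds a single 1, multiplying by K is the same as indexing with `col`. The `apply_commutation` helper below is hypothetical, a shortcut for when only the product K @ v is needed rather than K itself.

```python
import numpy as np
from scipy.sparse import csr_matrix


def vec(A):
    # Column-major (Fortran-order) vectorisation
    return A.reshape(-1, order='F')


def commutation_matrix_sp(A):
    # Same construction as in the answer above
    m, n = A.shape
    row = np.arange(m * n)
    col = row.reshape((m, n), order='F').ravel()
    data = np.ones(m * n, dtype=np.int8)
    return csr_matrix((data, (row, col)), shape=(m * n, m * n))


def apply_commutation(v, m, n):
    # Hypothetical helper: K @ v without building K, since row r of K
    # has its single 1 in column col[r]
    col = np.arange(m * n).reshape((m, n), order='F').ravel()
    return v[col]


m, n = 2, 3
print(np.arange(m * n).reshape((m, n), order='F'))
# [[0 2 4]
#  [1 3 5]]
print(np.arange(m * n).reshape((m, n), order='F').ravel())
# [0 2 4 1 3 5]

A = np.random.rand(m, n)
K = commutation_matrix_sp(A)
print(np.array_equal(apply_commutation(vec(A), m, n), K @ vec(A)))  # True
print(np.array_equal(apply_commutation(vec(A), m, n), vec(A.T)))    # True
```

Indexing with `col` needs only the m*n-element permutation, so it avoids storing the sparse matrix's index arrays when the matrix itself is never used.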

Test:

A = np.random.rand(500, 500)
K = commutation_matrix_sp(A)

print(f'{K.data.nbytes/2**20:.2f} MB')
# 0.24 MB

print(np.all(K @ vec(A) == vec(A.T)))
# True
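For a fuller picture of memory use, a CSR matrix stores three arrays: `K.data` (one value per nonzero), `K.indices` (one column index per nonzero) and `K.indptr` (m*n + 1 row pointers). The sketch below sums all three; the index dtype is chosen by SciPy (typically `int32` at this size), so the total is not hard-coded here. For comparison, a dense float64 K for a 500×500 input would need (250,000)² × 8 bytes = 5 × 10¹¹ bytes, about 466 GiB, so K has to stay sparse. `csr_nbytes` is a hypothetical helper for this illustration.

```python
import numpy as np
from scipy.sparse import csr_matrix


def commutation_matrix_sp(A):
    # Same construction as in the answer above
    m, n = A.shape
    row = np.arange(m * n)
    col = row.reshape((m, n), order='F').ravel()
    data = np.ones(m * n, dtype=np.int8)
    return csr_matrix((data, (row, col)), shape=(m * n, m * n))


def csr_nbytes(K):
    # Hypothetical helper: total bytes held by the three arrays of a CSR matrix
    return K.data.nbytes + K.indices.nbytes + K.indptr.nbytes


A = np.random.rand(500, 500)
K = commutation_matrix_sp(A)
print(f'data only: {K.data.nbytes / 2**20:.2f} MB')
print(f'data + indices + indptr: {csr_nbytes(K) / 2**20:.2f} MB')

# Dense float64 equivalent, computed arithmetically (never allocated)
N = A.size
print(f'dense float64 K would need: {N * N * 8 / 2**30:.0f} GiB')
```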
Answered Dec 10 '25 21:12 by Andreas K.
