
Get unique rows from a Scipy sparse matrix

I'm working with sparse matrices in Python. Is there an efficient way to remove duplicate rows from a sparse matrix, keeping only the unique rows?

I did not find a function for this, and I'm not sure how to do it without converting the sparse matrix to dense and using numpy.unique.

asked Sep 09 '17 by Dora

2 Answers

There is no built-in way to do it, so I had to write a function. It returns a sparse matrix with the unique rows (axis=0) or columns (axis=1) of an input sparse matrix. Note that the unique rows or columns of the returned matrix are not lexicographically sorted (as they would be with np.unique).

import numpy as np
import scipy.sparse as sp

def sp_unique(sp_matrix, axis=0):
    ''' Returns a sparse matrix with the unique rows (axis=0)
    or columns (axis=1) of an input sparse matrix sp_matrix'''
    if axis == 1:
        sp_matrix = sp_matrix.T

    old_format = sp_matrix.getformat()
    dt = sp_matrix.dtype
    ncols = sp_matrix.shape[1]

    if old_format != 'lil':
        sp_matrix = sp_matrix.tolil()

    # In LIL format, data and rows are object arrays of per-row lists;
    # element-wise + concatenates each row's values with its column indices,
    # so np.unique deduplicates on both together.
    _, ind = np.unique(sp_matrix.data + sp_matrix.rows, return_index=True)
    rows = sp_matrix.rows[ind]
    data = sp_matrix.data[ind]
    nrows_uniq = data.shape[0]

    sp_matrix = sp.lil_matrix((nrows_uniq, ncols), dtype=dt)  #  or sp_matrix.resize(nrows_uniq, ncols)
    sp_matrix.data = data
    sp_matrix.rows = rows

    ret = sp_matrix.asformat(old_format)
    if axis == 1:
        ret = ret.T        
    return ret


def lexsort_row(A):
    ''' numpy lexsort of the rows, not used in sp_unique'''
    return A[np.lexsort(A.T[::-1])]

if __name__ == '__main__':    
    # Test
    # Create a large sparse matrix with elements in [0, 10]
    A = 10*sp.random(10000, 3, 0.5, format='csr')
    A = np.ceil(A).astype(int)

    # unique rows
    A_uniq = sp_unique(A, axis=0).toarray()
    A_uniq = lexsort_row(A_uniq)
    A_uniq_numpy = np.unique(A.toarray(), axis=0)
    assert (A_uniq == A_uniq_numpy).all()

    # unique columns
    A_uniq = sp_unique(A, axis=1).toarray()
    A_uniq = lexsort_row(A_uniq.T).T
    A_uniq_numpy = np.unique(A.toarray(), axis=1)
    assert (A_uniq == A_uniq_numpy).all()  
answered Oct 31 '22 by Andreas K.

One could also use slicing:

def remove_duplicate_rows(data):
    unique_row_indices, seen = [], set()
    for row_idx, row in enumerate(data):
        # Key on both the column indices and the stored values, so rows
        # with the same sparsity pattern but different values are kept.
        key = (tuple(row.indices), tuple(row.data))
        if key not in seen:
            seen.add(key)
            unique_row_indices.append(row_idx)
    return data[unique_row_indices]

I found this especially helpful in a supervised machine-learning setting, where the input to my function was data and labels. With this approach, I could easily return

labels[unique_row_indices]

as well, to make sure data and labels stay aligned after this clean-up.
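The data-plus-labels pattern described above can be sketched end to end. This is a minimal, self-contained variant (the function name and the indices-plus-values key are illustrative choices, not part of the original answer):

```python
import numpy as np
import scipy.sparse as sp

def remove_duplicate_rows_with_labels(data, labels):
    """Drop duplicate rows of a CSR matrix and keep labels in sync.

    Rows are compared by both their column indices and their stored
    values, so rows with the same sparsity pattern but different
    values are not collapsed.
    """
    seen, keep = set(), []
    for row_idx in range(data.shape[0]):
        row = data.getrow(row_idx)
        key = (tuple(row.indices), tuple(row.data))
        if key not in seen:
            seen.add(key)
            keep.append(row_idx)
    # Fancy indexing slices the kept rows from both arrays in one step.
    return data[keep], labels[keep]

# Rows 0 and 2 are duplicates, so only rows 0 and 1 survive.
X = sp.csr_matrix(np.array([[1, 0, 2],
                            [0, 3, 0],
                            [1, 0, 2]]))
y = np.array([0, 1, 0])
X_u, y_u = remove_duplicate_rows_with_labels(X, y)
```

Here `X_u` keeps the first occurrence of each distinct row and `y_u` keeps the matching labels, which is the alignment guarantee the clean-up is after.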

answered Oct 31 '22 by AlexConfused