
Argmax of each row or column in scipy sparse matrix

scipy.sparse.coo_matrix.max returns the maximum value of each row or column, given an axis. I would like to know not the value but the index of the maximum value of each row or column. I haven't found an efficient way to do this yet, so I'll gladly accept any help.

Jimmy C asked Jun 09 '15 at 20:06





1 Answer

As others mention, there is now a built-in argmax() for scipy.sparse matrices. However, I found it to be quite slow for large matrices, so I had a look at the source code. The logic is clever, but it contains a Python loop over the rows, which slows things down. Taking that source code, reducing it to the per-row argmax case (sacrificing all generality, shape checking etc. for simplicity) and decorating it with Numba can give a large speedup.
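For reference, the built-in method takes the same axis argument as max: axis=1 gives one column index per row, axis=0 one row index per column. A minimal sketch on a small, made-up 3×3 matrix (the matrix `A` is purely illustrative):

```python
import numpy as np
from scipy.sparse import csr_matrix

# Small illustrative matrix; row 1 has no stored entries at all.
A = csr_matrix(np.array([[0, 3, 1],
                         [0, 0, 0],
                         [5, 0, 2]]))

# On a csr_matrix, argmax(axis=...) returns a 2-D np.matrix;
# np.asarray(...).ravel() turns it into a flat 1-D index array.
row_argmax = np.asarray(A.argmax(axis=1)).ravel()  # column index of each row's max
col_argmax = np.asarray(A.argmax(axis=0)).ravel()  # row index of each column's max

print(row_argmax)  # [1 0 0]
print(col_argmax)  # [2 0 2]
```

The empty row maps to index 0, because all of its entries are (implicit) zeros and the first one wins.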

Here's the function:

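```python
import numpy as np
from numba import jit


def argmax_row_numba(X):
    # X must be a CSR matrix (rows are the compressed axis)
    return _argmax_row_numba(X.shape[0], X.indptr, X.data, X.indices)


@jit(nopython=True)
def _argmax_row_numba(n_rows, indptr, data, indices):
    # one integer column index per row; rows without stored entries keep 0
    ret = np.zeros(n_rows, dtype=np.int64)
    # figure out which rows actually contain stored data
    nz_lines, = np.diff(indptr).nonzero()
    # loop through the non-empty rows
    for i in nz_lines:
        p = indptr[i]
        q = indptr[i + 1]
        line_data = data[p:q]
        line_indices = indices[p:q]
        # position of the largest *stored* value; scipy's check for implicit
        # zeros is skipped, so this agrees with argmax when that value is > 0
        am = np.argmax(line_data)
        ret[i] = line_indices[am]

    return ret
```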
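The reduced function above keeps only the stored-value part of scipy's logic: for a row whose stored entries are -1 and -3 in columns 0 and 2 of a 3-column matrix, it returns column 0, while the true maximum is the implicit zero in column 1. The sketch below adds that check back, and handles columns as well by running the same loop over a CSC copy. It is a sketch under stated assumptions, not scipy's code: `argmax_sparse` and `_argmax_lines` are hypothetical names, Numba is optional (without it the loop runs as slow pure Python), and the loop needs sorted, duplicate-free indices, which the wrapper enforces with sum_duplicates().

```python
import numpy as np
from scipy.sparse import csr_matrix

try:
    from numba import njit  # compile the loop if Numba is installed
except ImportError:
    def njit(func):  # assumption: fall back to running the loop as plain Python
        return func


@njit
def _argmax_lines(n_lines, line_len, indptr, data, indices):
    # One result per compressed line: a row of a CSR matrix, a column of a CSC one.
    ret = np.zeros(n_lines, dtype=np.int64)
    for i in range(n_lines):
        p = indptr[i]
        q = indptr[i + 1]
        if p == q:
            continue  # no stored entries: the line is all zeros, so index 0
        best = p
        for k in range(p + 1, q):
            if data[k] > data[best]:
                best = k  # strict '>' keeps the first (smallest-index) maximum
        best_val = data[best]
        if best_val > 0 or q - p == line_len:
            ret[i] = indices[best]  # the stored maximum is the true maximum
            continue
        # Otherwise an implicit zero ties with or beats every stored value.
        # With sorted indices, the first missing index is the first k where
        # indices[p + k] != k, or q - p if indices 0 .. q-p-1 are all present.
        zero_idx = q - p
        for k in range(q - p):
            if indices[p + k] != k:
                zero_idx = k
                break
        if best_val == 0 and indices[best] < zero_idx:
            ret[i] = indices[best]  # an explicitly stored zero comes first
        else:
            ret[i] = zero_idx
    return ret


def argmax_sparse(X, axis=1):
    # axis=1: column index of each row's maximum; axis=0: row index of each column's.
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1")
    Y = X.tocsr() if axis == 1 else X.tocsc()
    if not Y.has_canonical_format:
        Y = Y.copy()
        Y.sum_duplicates()  # merges duplicate entries and sorts the indices
    n_lines, line_len = Y.shape if axis == 1 else Y.shape[::-1]
    return _argmax_lines(n_lines, line_len, Y.indptr, Y.data, Y.indices)
```

The extra zero search only runs for lines whose stored maximum is not positive, so for non-negative data the work per line is essentially the same single pass as the simplified loop; no timings are claimed for this variant.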

Generating a matrix for testing:


```python
from scipy.sparse import random

size = 10000
m = random(m=size, n=size, density=0.0001, format="csr")
n_vals = m.data.shape[0]
# overwrite the stored values with uniform floats in [0, 1)
m.data = np.random.random(size=n_vals)

# the original scipy implementation, reformatted to return a flat np.array
maxima1 = np.squeeze(np.array(m.argmax(axis=1)))
# calling the numba version
maxima2 = argmax_row_numba(m)

# check that the results are the same
print(np.array_equal(maxima1, maxima2))
# True
```

Timing results in IPython (roughly 140 times faster on this matrix):

```python
%timeit m.argmax(axis=1)
# 30.1 ms ± 246 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit argmax_row_numba(m)
# 211 µs ± 1.04 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
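If adding Numba as a dependency is not an option, the same simplified per-row argmax can be vectorized in plain NumPy by reducing over the indptr segments with np.maximum.reduceat. This is a different technique from the Numba loop and is offered as an untimed sketch; `argmax_row_numpy` is a hypothetical name, and it inherits the same assumption that each non-empty row's largest stored value is its maximum (true for non-negative data).

```python
import numpy as np
from scipy.sparse import csr_matrix


def argmax_row_numpy(X):
    # Assumption (same as the Numba version): the largest *stored* value of
    # each non-empty row is that row's maximum, e.g. for non-negative data.
    X = X.tocsr()
    n_rows = X.shape[0]
    ret = np.zeros(n_rows, dtype=np.int64)  # empty rows keep index 0
    counts = np.diff(X.indptr)
    nz_rows = np.flatnonzero(counts)
    if nz_rows.size == 0:
        return ret
    data = X.data[:X.indptr[-1]]
    # Maximum stored value of every non-empty row: one reduceat segment per row
    # (empty rows in between contribute no entries to a segment).
    row_max = np.maximum.reduceat(data, X.indptr[nz_rows])
    # Row number of every stored entry, used to broadcast each row's maximum.
    entry_row = np.repeat(np.arange(n_rows), counts)
    max_per_row = np.zeros(n_rows, dtype=data.dtype)
    max_per_row[nz_rows] = row_max
    hits = np.flatnonzero(data == max_per_row[entry_row])
    # First hit per row in storage order, like np.argmax on each row slice.
    rows_hit, first = np.unique(entry_row[hits], return_index=True)
    ret[rows_hit] = X.indices[hits[first]]
    return ret


A = csr_matrix(np.array([[0, 3, 1],
                         [0, 0, 0],
                         [5, 0, 2]]))
print(argmax_row_numpy(A))  # [1 0 0]
```

Unlike the Numba loop, this allocates a few temporary arrays of length nnz, trading memory for the absence of any Python-level loop.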
L_W answered Sep 20 '22 at 05:09
