Sum matrix elements grouped by indices in Python

I have two matrices of the same shape: one holds float values, which are grouped by the indices in the other matrix. As a result, I want a dictionary or a list with the sum of the elements for each index. Indices always start at 0.

A = np.array([[0.52,0.25,-0.45,0.13],[-0.14,-0.41,0.31,-0.41]])
B = np.array([[1,3,1,2],[3,0,2,2]])

RESULT = {0: -0.41, 1: 0.07, 2: 0.03, 3: 0.11}
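As a reference for checking faster versions, a plain-Python baseline that spells out the expected result might look like the sketch below. It is not from the original post, and group_sums_loop is a hypothetical name; it loops in Python, so it will be slow on large matrices, but it makes the grouping explicit.

```python
from collections import defaultdict

import numpy as np


def group_sums_loop(indices, values):
    """Sum values per group index with an explicit loop (hypothetical reference helper)."""
    sums = defaultdict(float)
    # Walk the flattened arrays together, pairing each index with its value
    for i, v in zip(indices.ravel(), values.ravel()):
        sums[int(i)] += float(v)
    # Return a plain dict ordered by index
    return dict(sorted(sums.items()))


A = np.array([[0.52, 0.25, -0.45, 0.13], [-0.14, -0.41, 0.31, -0.41]])
B = np.array([[1, 3, 1, 2], [3, 0, 2, 2]])
result = group_sums_loop(B, A)
print({k: round(v, 2) for k, v in result.items()})  # {0: -0.41, 1: 0.07, 2: 0.03, 3: 0.11}
```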

I found this solution, but I'm looking for a faster one. I'm working with matrices of 784 x 300 cells, and this algorithm takes ~28 ms to complete.

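import numpy as np

def matrix_sum_by_indices(indices, matrix):
    # Flatten both 2D arrays (np.hstack joins the rows end to end)
    a = np.hstack(indices)
    b = np.hstack(matrix)
    # Order the elements by their group index
    sidx = a.argsort()
    # Positions where the sorted index changes mark the group boundaries
    split_idx = np.flatnonzero(np.diff(a[sidx]) > 0) + 1
    out = np.split(b[sidx], split_idx)
    # Sum each group
    return [sum(x) for x in out]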
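A possible vectorized variant of this same sort-based idea replaces np.split and the Python-level sum with np.add.reduceat, which sums every run of equal sorted indices in one call. This is a sketch, not benchmarked here, and matrix_sum_by_indices_reduceat is a hypothetical name; like the original, it returns sums only for indices that actually occur.

```python
import numpy as np


def matrix_sum_by_indices_reduceat(indices, matrix):
    """Sort-based group sums using np.add.reduceat (hypothetical variant)."""
    a = indices.ravel()
    b = matrix.ravel()
    # Stable sort so equal indices keep their original relative order
    sidx = a.argsort(kind="stable")
    sorted_idx = a[sidx]
    # Start offset of each run of equal indices in the sorted array
    starts = np.flatnonzero(np.r_[True, sorted_idx[1:] != sorted_idx[:-1]])
    # reduceat sums each slice starts[k]:starts[k+1]; the last run goes to the end
    return np.add.reduceat(b[sidx], starts)


A = np.array([[0.52, 0.25, -0.45, 0.13], [-0.14, -0.41, 0.31, -0.41]])
B = np.array([[1, 3, 1, 2], [3, 0, 2, 2]])
sums = matrix_sum_by_indices_reduceat(B, A)
print(sums)
```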

If you can help me find a faster and simpler solution to this problem, I'll be grateful!

EDIT: I made a mistake: the time to complete is ~8 ms for a 300 x 10 matrix, but ~28 ms for a 784 x 300 one.

EDIT2: The elements of A are float64, so bincount gives me a ValueError.

Mortafix asked Jul 10 '18 at 14:07



1 Answer

You can make use of bincount here, passing the integer indices as the first argument and the float values as the weights:

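import numpy as np

a = np.array([[0.52,0.25,-0.45,0.13],[-0.14,-0.41,0.31,-0.41]])
b = np.array([[1,3,1,2],[3,0,2,2]])

N = b.max() + 1
# Offset each row's indices by N * row number so every (row, index) pair gets its own bin
idx = b + (N * np.arange(b.shape[0]))[:, None]
# bincount only accepts 1D input, so flatten; minlength guarantees rows * N bins for the reshape
np.sum(np.bincount(idx.ravel(), a.ravel(), minlength=N * b.shape[0]).reshape(a.shape[0], -1), axis=0)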

Output:

array([-0.41,  0.07,  0.03,  0.11])
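To see what the row offset does on the sample, the intermediate arrays can be printed. This sketch repeats the steps above with print statements; the minlength argument is an addition not present in the original snippet, included so the reshape also works when the last row lacks the largest index.

```python
import numpy as np

a = np.array([[0.52, 0.25, -0.45, 0.13], [-0.14, -0.41, 0.31, -0.41]])
b = np.array([[1, 3, 1, 2], [3, 0, 2, 2]])

N = b.max() + 1  # 4 groups
# Row r's indices are shifted by r * N: row 0 uses bins 0..3, row 1 uses bins 4..7
idx = b + (N * np.arange(b.shape[0]))[:, None]
print(idx)
# [[1 3 1 2]
#  [7 4 6 6]]

# One weighted bincount over all bins, reshaped so each matrix row gets N group sums
per_row = np.bincount(idx.ravel(), a.ravel(), minlength=N * b.shape[0]).reshape(b.shape[0], N)
print(per_row.round(2))
# Adding the rows together gives the per-group totals
print(per_row.sum(axis=0).round(2))
```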

As a function:

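def using_bincount(indices, matrix):
    # Number of distinct group indices (indices start at 0)
    N = indices.max() + 1
    # Give each row its own block of N bins, since bincount cannot take a 2D array
    idx = indices + (N * np.arange(indices.shape[0]))[:, None]
    # Weighted bincount sums the values per bin; reshape to (rows, N) and add the rows together
    sums = np.bincount(idx.ravel(), matrix.ravel(), minlength=N * indices.shape[0])
    return sums.reshape(matrix.shape[0], -1).sum(axis=0)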
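Since the function adds the per-row results together at the end, a single weighted bincount over the flattened arrays should give the same sums without any row offsets. The sketch below is a simplification, not part of the original answer. It also shows where the float64 error from EDIT2 likely comes from: np.bincount needs non-negative integers as its first argument, whereas float weights are fine. The exact exception type for float input may vary, so both TypeError and ValueError are caught.

```python
import numpy as np

a = np.array([[0.52, 0.25, -0.45, 0.13], [-0.14, -0.41, 0.31, -0.41]])
b = np.array([[1, 3, 1, 2], [3, 0, 2, 2]])

# Integer group indices as the bins, float64 values as the weights
direct = np.bincount(b.ravel(), weights=a.ravel())
print(direct.round(2))

# Passing the float matrix as the first argument is what fails
try:
    np.bincount(a.ravel())
    rejected = False
except (TypeError, ValueError) as exc:
    rejected = True
    print("bincount rejects float input:", type(exc).__name__)
```

If some of the highest group indices may be absent, passing minlength set to the total number of groups keeps the output length fixed.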

Timings on this sample:

In [5]: %timeit using_bincount(b, a)
31.1 µs ± 1.74 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [6]: %timeit matrix_sum_by_indices(b, a)
61.3 µs ± 2.62 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [88]: %timeit scipy.ndimage.sum(a, b, index=[0,1,2,3])
54 µs ± 218 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

(scipy.ndimage.sum should be faster on much larger samples)
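A timing sketch at the question's 784 x 300 size could be used to check that claim. It first confirms that the methods agree, then times them. The random data, the 10 groups, the seed and the repeat counts are assumptions, SciPy is treated as optional, and no results are claimed here, since they depend on the machine and the library versions.

```python
import timeit

import numpy as np

try:
    import scipy.ndimage
except ImportError:  # SciPy is optional here; its timing is skipped if missing
    scipy = None


def matrix_sum_by_indices(indices, matrix):
    # Sort-and-split approach from the question
    a = np.hstack(indices)
    b = np.hstack(matrix)
    sidx = a.argsort()
    split_idx = np.flatnonzero(np.diff(a[sidx]) > 0) + 1
    out = np.split(b[sidx], split_idx)
    return [sum(x) for x in out]


def using_bincount(indices, matrix):
    # Row-offset bincount from the answer, with minlength so the reshape always fits
    N = indices.max() + 1
    idx = indices + (N * np.arange(indices.shape[0]))[:, None]
    sums = np.bincount(idx.ravel(), matrix.ravel(), minlength=N * indices.shape[0])
    return sums.reshape(matrix.shape[0], -1).sum(axis=0)


# Hypothetical data at the question's size: 784 x 300 floats in 10 groups
rng = np.random.default_rng(0)
values = rng.standard_normal((784, 300))
groups = rng.integers(0, 10, size=(784, 300))

# Summation order differs between methods, so compare with a tolerance
reference = np.array(matrix_sum_by_indices(groups, values))
print("results agree:", np.allclose(reference, using_bincount(groups, values)))

candidates = {
    "matrix_sum_by_indices": lambda: matrix_sum_by_indices(groups, values),
    "using_bincount": lambda: using_bincount(groups, values),
}
if scipy is not None:
    labels = np.arange(groups.max() + 1)
    candidates["scipy.ndimage.sum"] = lambda: scipy.ndimage.sum(values, groups, index=labels)

timings = {}
for name, fn in candidates.items():
    # Best of 5 repeats of 10 calls each, reported per call
    timings[name] = min(timeit.repeat(fn, number=10, repeat=5)) / 10
    print(f"{name}: {timings[name] * 1e3:.3f} ms per call")
```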

user3483203 answered Oct 20 '22 at 00:10
