
Cumulative counts in NumPy without iteration

Tags: python, numpy

I have an array like so:

a = np.array([0.1, 0.2, 1.0, 1.0, 1.0, 0.9, 0.6, 1.0, 0.0, 1.0])

I'd like to have a running counter of instances of 1.0 that resets when it encounters a 0.0, so the result would be:

[0, 0, 1, 2, 3, 3, 3, 4, 0, 1]

My initial thought was to use something like b = np.cumsum(a[a==1.0]), but I don't know how to (1) modify this to reset at zeros or (2) quite how to structure it so the output array is the same shape as the input array. Any ideas how to do this without iteration?

triphook asked Dec 01 '15 18:12


1 Answer

I think you could do something like

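import numpy as np

def rcount(a):
    # Running count of 1s, ignoring resets for now.
    without_reset = (a == 1).cumsum()
    # True wherever the counter should reset.
    reset_at = (a == 0)
    # Count as of the most recent reset (0 before the first reset).
    overcount = np.maximum.accumulate(without_reset * reset_at)
    result = without_reset - overcount
    return result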

which gives me

>>> a = np.array([0.1, 0.2, 1.0, 1.0, 1.0, 0.9, 0.6, 1.0, 0.0, 1.0])
>>> rcount(a)
array([0, 0, 1, 2, 3, 3, 3, 4, 0, 1])

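This works because without_reset never decreases, so the cumulative maximum of its values at the reset positions is the count as of the most recent reset, i.e. the "overcount" we need to subtract. With the intermediate arrays from rcount evaluated for the a above: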

>>> without_reset * reset_at
array([0, 0, 0, 0, 0, 0, 0, 0, 4, 0])
>>> np.maximum.accumulate(without_reset * reset_at)
array([0, 0, 0, 0, 0, 0, 0, 0, 4, 4])
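
The example above has only one reset, so here is a small sketch of the same idea on an input with several. The helper rcount_steps and the array b are made up for illustration; the helper just returns the intermediate arrays so they can be inspected.

```python
import numpy as np

def rcount_steps(a):
    # Hypothetical helper: same logic as rcount, but also returns
    # the intermediate arrays.
    a = np.asarray(a)
    without_reset = (a == 1).cumsum()
    reset_at = (a == 0)
    overcount = np.maximum.accumulate(without_reset * reset_at)
    return without_reset, overcount, without_reset - overcount

# Illustrative input with three resets, two of them back to back.
b = np.array([1, 1, 0, 1, 0, 0, 1, 1])
without_reset, overcount, result = rcount_steps(b)

print(without_reset)  # [1 2 2 3 3 3 4 5]
print(overcount)      # [0 0 2 2 3 3 3 3]
print(result)         # [1 2 0 1 0 0 1 2]
```

Each reset raises the overcount to the running total at that position, so the difference drops back to zero there and counts up again afterwards.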

Sanity testing:

def manual(arr):
    out = []
    count = 0
    for x in arr:
        if x == 1:
            count += 1
        if x == 0:
            count = 0
        out.append(count)
    return out

def test():
    for w in [1, 2, 10, 10**4]:
        for trial in range(100):
            for vals in [0,1],[0,1,2]:
                b = np.random.choice(vals, size=w)
                assert (rcount(b) == manual(b)).all()
    print("hooray!")

and then

>>> test()
hooray!
DSM answered Sep 20 '22 10:09