 

How to replace only the first n elements in a numpy array that are larger than a certain value?

I have an array myA like this:

array([ 7,  4,  5,  8,  3, 10])

If I want to replace all values that are larger than a value val by 0, I can simply do:

myA[myA > val] = 0

which gives me the desired output (for val = 5):

 array([0, 4, 5, 0, 3, 0])

However, my goal is to replace not all but only the first n elements of this array that are larger than a value val.

So, if n = 2, my desired outcome would look like this (10 is the third such element and should therefore not be replaced):

array([ 0,  4,  5,  0,  3, 10])

A straightforward implementation would be:

import numpy as np

myA = np.array([7, 4, 5, 8, 3, 10])
n = 2
val = 5

# track the number of replacements
repl = 0

for ind, vali in enumerate(myA):

    if vali > val:

        myA[ind] = 0
        repl += 1

        if repl == n:
            break

That works, but maybe someone can come up with a smarter way using masking?

asked Jan 26 '16 14:01 by Cleb


3 Answers

The following should work:

myA[(myA > val).nonzero()[0][:n]] = 0

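This works because nonzero returns the indices where the boolean array myA > val is nonzero, i.e. True. Slicing with [:n] keeps only the first n of those indices, so only the first n matching elements are overwritten.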

For example:

In [1]: import numpy as np

In [2]: myA = np.array([ 7,  4,  5,  8,  3, 10])

In [3]: myA[(myA > 5).nonzero()[0][:2]] = 0

In [4]: myA
Out[4]: array([ 0,  4,  5,  0,  3, 10])
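To make the one-liner easier to follow, here is a sketch that splits it into its intermediate steps. The variable names mask, idx_tuple and first_n are only for illustration; the comments show what each step holds for the example array with val = 5 and n = 2.

```python
import numpy as np

myA = np.array([7, 4, 5, 8, 3, 10])
val, n = 5, 2

mask = myA > val             # boolean mask: [ True False False  True False  True]
idx_tuple = mask.nonzero()   # tuple with one index array per dimension: (array([0, 3, 5]),)
first_n = idx_tuple[0][:n]   # keep only the first n matching positions: array([0, 3])

myA[first_n] = 0             # fancy-index assignment writes into myA in place
print(myA)                   # [ 0  4  5  0  3 10]
```

Because the slice [:n] simply stops at the end of the index array, the same line also works when fewer than n elements exceed val.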
answered Oct 18 '22 17:10 by JuniorCompressor


Final solution is very simple:

import numpy as np
myA = np.array([7, 4, 5, 8, 3, 10])
n = 2
val = 5

myA[np.where(myA > val)[0][:n]] = 0

print(myA)

Output:

[ 0  4  5  0  3 10]
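If you prefer not to modify the input array, the same idea can be wrapped in a small function that works on a copy. This is a sketch: zero_first_n_greater is a hypothetical helper name, and it uses np.flatnonzero, which for a 1-D array returns the same indices as np.where(a > val)[0].

```python
import numpy as np


def zero_first_n_greater(a, val, n):
    """Return a copy of 1-D array `a` with the first n entries > val set to 0 (hypothetical helper)."""
    out = a.copy()
    # indices of all entries > val; slicing past the end is safe, so n may exceed the match count
    idx = np.flatnonzero(out > val)[:n]
    out[idx] = 0
    return out


myA = np.array([7, 4, 5, 8, 3, 10])
print(zero_first_n_greater(myA, 5, 2))   # [ 0  4  5  0  3 10]
print(zero_first_n_greater(myA, 5, 10))  # [0 4 5 0 3 0]
print(myA)                               # original is unchanged: [ 7  4  5  8  3 10]
```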
answered Oct 18 '22 18:10 by George Petrov


Here's another possibility (untested), probably no better than nonzero:

def truncate_mask(m, stop):
    m = m.astype(bool, copy=False)  # if we allow non-bool m, the next line becomes nonsense
    return m & (np.cumsum(m) <= stop)

myA[truncate_mask(myA > val, n)] = 0

By avoiding building and using an explicit index, you might end up with slightly better performance, but you'd have to test it to find out.
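One way to run that test is a small timeit harness. This is a sketch under assumptions: the wrapper names via_nonzero and via_cumsum and the random test data are illustrative, and both wrappers copy the array so every call starts from the same input (the copy is therefore included in the timing). No timings are claimed here; results will depend on your machine, the array size, n and how many values exceed val.

```python
import timeit

import numpy as np


def truncate_mask(m, stop):
    # cumsum-based mask from the answer above: keep only the first `stop` True values
    m = m.astype(bool, copy=False)
    return m & (np.cumsum(m) <= stop)


def via_nonzero(a, val, n):
    # work on a copy so every timed call sees the same input
    out = a.copy()
    out[(out > val).nonzero()[0][:n]] = 0
    return out


def via_cumsum(a, val, n):
    out = a.copy()
    out[truncate_mask(out > val, n)] = 0
    return out


# Assumed test data: uniform random integers
rng = np.random.default_rng(0)
a = rng.integers(0, 100, size=100_000)
val, n = 50, 1_000

# both approaches should give identical results before comparing speed
assert np.array_equal(via_nonzero(a, val, n), via_cumsum(a, val, n))

for name, func in [("nonzero", via_nonzero), ("cumsum", via_cumsum)]:
    seconds = timeit.timeit(lambda: func(a, val, n), number=20)
    print(f"{name}: {seconds / 20 * 1e3:.3f} ms per call")
```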

Edit 1: while we're on the subject of possibilities, you could also try:

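def truncate_mask(m, stop):
    m = m.astype(bool, copy=True)  # note we need to copy m here to safely modify it
    # side="right" gives the first position whose running count of True values exceeds stop
    m[np.searchsorted(np.cumsum(m), stop, side="right"):] = False
    return m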
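Tracing both cumsum-based masks on the example array shows how they pick the cut-off. In this sketch, note that np.searchsorted defaults to side="left", which returns the index where the running count first *reaches* stop; that would cut one position too early and drop the stop-th True value. With side="right" the cut lands at the first index where the count *exceeds* stop, which agrees with the m & (cumsum <= stop) mask.

```python
import numpy as np

m = np.array([7, 4, 5, 8, 3, 10]) > 5   # [ True False False  True False  True]
stop = 2

cs = np.cumsum(m)                       # running count of True values: [1 1 1 2 2 3]

# first variant: keep positions whose running count is still <= stop
print(m & (cs <= stop))                 # [ True False False  True False False]

# searchsorted variant: first index whose running count exceeds stop
cut = np.searchsorted(cs, stop, side="right")
print(cut)                              # 5
print(np.searchsorted(cs, stop))        # 3 with the default side="left" (one too early)

trimmed = m.copy()
trimmed[cut:] = False
print(trimmed)                          # [ True False False  True False False]
```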

Edit 2 (the next day): I've just tested this and it seems that cumsum is actually worse than nonzero, at least with the kinds of values I was using (so neither of the above approaches is worth using). Out of curiosity, I also tried it with numba:

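import numba

@numba.njit  # compile in nopython mode
def set_first_n_gt_thresh(a, val, thresh, n):
    # overwrite, in place, the first n entries of a that exceed thresh with val
    ii = 0
    while n > 0 and ii < len(a):  # the loop ends as soon as n replacements are done
        if a[ii] > thresh:
            a[ii] = val
            n -= 1  # one fewer replacement left to make
        ii += 1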

This iterates over the array only once, or rather only over the part of the array it needs, never touching the rest. That gives vastly superior performance for small n, and even in the worst case of n >= len(a) this approach is faster.
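A usage sketch for the question's example follows. Note the argument order: here val is the *replacement* value and thresh is the threshold, whereas in the question val was the threshold. As an assumption for environments without numba, the fallback below replaces njit with a no-op decorator; the function then runs as plain Python with the same result, just without the speedup.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: numba is not installed, so fall back to plain Python (same result, no speedup)
    def njit(func):
        return func


@njit
def set_first_n_gt_thresh(a, val, thresh, n):
    # same loop as the answer above: overwrite the first n entries > thresh with val, in place
    ii = 0
    while n > 0 and ii < len(a):
        if a[ii] > thresh:
            a[ii] = val
            n -= 1
        ii += 1


myA = np.array([7, 4, 5, 8, 3, 10])
set_first_n_gt_thresh(myA, 0, 5, 2)  # array, replacement value, threshold, n
print(myA)                           # [ 0  4  5  0  3 10]
```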

answered Oct 18 '22 18:10 by dan-man