I have an array myA like this:
array([ 7, 4, 5, 8, 3, 10])
If I want to replace all values that are larger than a value val by 0, I can simply do:
myA[myA > val] = 0
which gives me the desired output (for val = 5):
array([0, 4, 5, 0, 3, 0])
However, my goal is to replace not all but only the first n elements of this array that are larger than a value val.
So, if n = 2, my desired outcome would look like this (10 is the third such element and should therefore not be replaced):
array([ 0, 4, 5, 0, 3, 10])
A straightforward implementation would be:
import numpy as np
myA = np.array([7, 4, 5, 8, 3, 10])
n = 2
val = 5
# track the number of replacements
repl = 0
for ind, vali in enumerate(myA):
    if vali > val:
        myA[ind] = 0
        repl += 1
        # stop as soon as n elements have been replaced
        if repl == n:
            break
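For reuse, the same loop can be wrapped in a function. This is a minimal sketch; the name replace_first_n_loop and the new parameter are hypothetical. One small difference: it checks the count before replacing, so n = 0 leaves the array untouched, whereas a loop that only checks repl == n after a replacement would replace every match when n = 0.

```python
import numpy as np

def replace_first_n_loop(a, val, n, new=0):
    """Hypothetical helper: in place, set the first n elements of a that are > val to new."""
    repl = 0
    for ind in range(len(a)):
        if repl >= n:  # checked before replacing, so n <= 0 replaces nothing
            break
        if a[ind] > val:
            a[ind] = new
            repl += 1
    return a

myA = np.array([7, 4, 5, 8, 3, 10])
print(replace_first_n_loop(myA, 5, 2))  # [ 0  4  5  0  3 10]
```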
That works, but maybe someone can come up with a smarter way of masking?
The following should work:
myA[(myA > val).nonzero()[0][:n]] = 0
since nonzero returns the indices where the boolean array myA > val is nonzero, i.e. True.
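To make the one-liner easier to follow, here it is split into its intermediate steps for the example array (the variable names mask, idx and first_n are only for illustration). Note that nonzero returns a tuple with one index array per dimension, which is why [0] is needed for a 1-D array.

```python
import numpy as np

myA = np.array([7, 4, 5, 8, 3, 10])
val, n = 5, 2

mask = myA > val         # boolean mask: [True, False, False, True, False, True]
idx = mask.nonzero()[0]  # positions of the True entries: [0, 3, 5]
first_n = idx[:n]        # keep only the first n positions: [0, 3]
myA[first_n] = 0         # integer-index assignment modifies myA in place
print(myA)               # [ 0  4  5  0  3 10]
```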
For example:
In [1]: myA = array([ 7, 4, 5, 8, 3, 10])
In [2]: myA[(myA > 5).nonzero()[0][:2]] = 0
In [3]: myA
Out[3]: array([ 0, 4, 5, 0, 3, 10])
The final solution is very simple:
import numpy as np
myA = np.array([7, 4, 5, 8, 3, 10])
n = 2
val = 5
myA[np.where(myA > val)[0][:n]] = 0
print(myA)
Output:
[ 0 4 5 0 3 10]
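Called with only a condition, np.where(cond) is equivalent to np.asarray(cond).nonzero(), so this answer and the nonzero one do the same thing. As a small variation not used in either answer, np.flatnonzero returns the indices directly as a 1-D array, which removes the [0]. A sketch:

```python
import numpy as np

myA = np.array([7, 4, 5, 8, 3, 10])
val, n = 5, 2

# with a single argument, np.where gives the same indices as nonzero()
assert np.array_equal(np.where(myA > val)[0], (myA > val).nonzero()[0])

# np.flatnonzero returns a flat 1-D index array, so no [0] is needed
idx = np.flatnonzero(myA > val)[:n]
myA[idx] = 0
print(myA)  # [ 0  4  5  0  3 10]
```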
Here's another possibility (untested), probably no better than nonzero:
def truncate_mask(m, stop):
    m = m.astype(bool, copy=False)  # if we allow non-bool m, the next line becomes nonsense
    # cumsum counts the True values seen so far; keep only the first `stop` of them
    return m & (np.cumsum(m) <= stop)

myA[truncate_mask(myA > val, n)] = 0
By avoiding building and using an explicit index, you might end up with slightly better performance, but you'd have to test it to find out.
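The cumsum trick is easiest to see on the example array: the running count of True values stays at or below n exactly up to the n-th match, and ANDing with the original mask keeps only the matches in that prefix. The sketch below repeats the truncate_mask function so it runs on its own; counts and keep are illustrative names.

```python
import numpy as np

def truncate_mask(m, stop):
    m = m.astype(bool, copy=False)
    return m & (np.cumsum(m) <= stop)

myA = np.array([7, 4, 5, 8, 3, 10])
m = myA > 5                 # [True, False, False, True, False, True]
counts = np.cumsum(m)       # running count of True values: [1, 1, 1, 2, 2, 3]
keep = truncate_mask(m, 2)  # True only where m is True and counts <= 2
myA[keep] = 0
print(myA)                  # [ 0  4  5  0  3 10]
```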
Edit 1: while we're on the subject of possibilities, you could also try:
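def truncate_mask(m, stop):
    m = m.astype(bool, copy=True)  # note we need to copy m here to safely modify it
    # side="right" keeps the stop-th True; the default side="left" would clear it as well
    m[np.searchsorted(np.cumsum(m), stop, side="right"):] = False
    return m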
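The side argument of np.searchsorted matters here. With the default side="left" it returns the first position where the running count reaches stop, i.e. the stop-th True itself, so slicing from there also clears that element and keeps only stop - 1 matches. With side="right" it returns the first position where the count exceeds stop, so exactly stop matches survive. A small check on the example array (left and right are illustrative names):

```python
import numpy as np

def truncate_mask(m, stop):
    m = m.astype(bool, copy=True)
    m[np.searchsorted(np.cumsum(m), stop, side="right"):] = False
    return m

m = np.array([7, 4, 5, 8, 3, 10]) > 5
cs = np.cumsum(m)                             # [1, 1, 1, 2, 2, 3]
left = np.searchsorted(cs, 2)                 # 3: index of the 2nd True itself
right = np.searchsorted(cs, 2, side="right")  # 5: index of the 3rd True
print(left, right, truncate_mask(m, 2))
```

If stop is at least the number of True values, searchsorted returns len(m) and nothing is cleared.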
Edit 2 (the next day): I've just tested this and it seems that cumsum is actually worse than nonzero, at least with the kinds of values I was using (so neither of the above approaches is worth using). Out of curiosity, I also tried it with numba:
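import numba

@numba.jit
def set_first_n_gt_thresh(a, val, thresh, n):
    # in place: set the first n elements of a that exceed thresh to val
    ii = 0
    while n > 0 and ii < len(a):
        if a[ii] > thresh:
            a[ii] = val
            n -= 1
        ii += 1

# usage for the question's example: set_first_n_gt_thresh(myA, 0, val, n)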
This only iterates over the array once, or rather it only iterates over the necessary part of the array once, never even touching the latter part. This gives you vastly superior performance for small n, but even for the worst case of n >= len(a) this approach is still faster.
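Before timing the approaches against each other, it is worth checking that they agree. The sketch below compares the numba loop's body with the nonzero and cumsum versions on random arrays, including empty arrays and n = 0. It drops the @numba.jit decorator so it runs without numba installed; the helper names via_nonzero and via_cumsum are hypothetical.

```python
import numpy as np

def set_first_n_gt_thresh(a, val, thresh, n):
    # same loop as the numba version, minus @numba.jit
    ii = 0
    while n > 0 and ii < len(a):
        if a[ii] > thresh:
            a[ii] = val
            n -= 1
        ii += 1

def via_nonzero(a, val, thresh, n):
    a[(a > thresh).nonzero()[0][:n]] = val

def via_cumsum(a, val, thresh, n):
    m = a > thresh
    a[m & (np.cumsum(m) <= n)] = val

rng = np.random.default_rng(0)
for _ in range(200):
    a = rng.integers(0, 20, size=int(rng.integers(0, 30)))
    thresh = int(rng.integers(0, 20))
    n = int(rng.integers(0, 35))
    b, c = a.copy(), a.copy()
    set_first_n_gt_thresh(a, 0, thresh, n)
    via_nonzero(b, 0, thresh, n)
    via_cumsum(c, 0, thresh, n)
    assert np.array_equal(a, b) and np.array_equal(a, c)
print("all approaches agree")
```

For actual timings, put the decorator back and call the jitted function once before measuring, since the first call includes compilation.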