 

Numpy - How to replace elements based on condition (or matching a pattern)

I have a numpy array, say:

>>> a = np.array([[0,1,2],[4,3,6],[9,5,7],[8,9,8]])
>>> a
array([[0, 1, 2],
       [4, 3, 6],
       [9, 5, 7],
       [8, 9, 8]])

I want to replace the second and third column elements with the minimum of the two (row by row), except when either of these two elements is < 3. The resulting array should be:

array([[0, 1, 2],   # nothing changes since 1 and 2 are < 3
       [4, 3, 3],   # min(3,6)=3 => 6 changed to 3
       [9, 5, 5],   # min(5,7)=5 => 7 changed to 5
       [8, 8, 8]])  # min(9,8)=8 => 9 changed to 8
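As a plain reference for the rule above, here is a slow but explicit per-row loop. It is only a sketch: `replace_with_row_min` is a hypothetical helper name, not a NumPy function. It is handy for checking a vectorized version against:

```python
import numpy as np

def replace_with_row_min(arr):
    """Hypothetical helper: return a copy of arr in which columns 1 and 2
    both take their row minimum, unless either of them is < 3."""
    out = arr.copy()
    for row in out:                  # each row is a view into out
        x, y = row[1], row[2]
        if x >= 3 and y >= 3:        # skip rows where either element is < 3
            row[1] = row[2] = min(x, y)
    return out

a = np.array([[0, 1, 2], [4, 3, 6], [9, 5, 7], [8, 9, 8]])
print(replace_with_row_min(a))
```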

I know I can use clip, for instance a[:,1:3].clip(2,6,a[:,1:3]), but

1) clip will be applied to all elements, including those <3.

2) I don't know how to set clip's min and max values to the minimum of the two related elements in each row.
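On point 2: `np.clip` accepts array-valued bounds that broadcast against the input, so each row can get its own upper bound. Below is a minimal sketch (an illustration, not code from the question) that clips both columns to their row minimum and then writes back only the rows not excluded by the < 3 rule:

```python
import numpy as np

a = np.array([[0, 1, 2], [4, 3, 6], [9, 5, 7], [8, 9, 8]])

cols = a[:, 1:3]                    # 2nd and 3rd columns (a view into a)
row_min = cols.min(axis=1)          # per-row minimum, shape (4,)
excluded = (cols < 3).any(axis=1)   # rows where either element is < 3

# A (4, 1) upper bound broadcasts across both columns, so each row
# is clipped to its own minimum; a_min=None means no lower bound.
clipped = np.clip(cols, None, row_min[:, np.newaxis])

# Write back only the rows that the < 3 rule does not exclude.
a[~excluded, 1:3] = clipped[~excluded]
print(a)
```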

asked by Dominique, Oct 21 '25 21:10



2 Answers

Just use the >= operator to first select what you are interested in:

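b = a[:, 1:3]  # select the columns (a view, so writes through b change a)
matching = np.all(b >= 3, axis=1)  # find rows where both elements are >= 3

Now you can replace the content of those rows with their minimum, e.g.:

# find the row minimum, kept as a column vector (keepdims=True) so it broadcasts;
# boolean indexing on the right returns a copy, so assign through the mask
# on the view b to modify a in place
b[matching, :] = b[matching, :].min(axis=1, keepdims=True)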
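The view/copy distinction is what makes the assignment work. Here is a self-contained sketch of the steps above on the question's array. It uses `np.shares_memory` to show that the column slice is a view, while the boolean-indexed result is a copy:

```python
import numpy as np

a = np.array([[0, 1, 2], [4, 3, 6], [9, 5, 7], [8, 9, 8]])

b = a[:, 1:3]                          # basic slicing: b is a view into a
matching = np.all(b >= 3, axis=1)      # rows whose two elements are both >= 3

print(np.shares_memory(a, b))            # True: writes to b reach a
print(np.shares_memory(a, b[matching]))  # False: boolean indexing copies

# Writing through the mask on the view updates a in place.
b[matching] = b[matching].min(axis=1, keepdims=True)
print(a)
```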
answered by agrinh, Oct 23 '25 11:10



We first define a row_mask marking the rows where neither the second nor the third element is < 3, and then take the minimum along axis 1 for those rows.

The np.newaxis turns the 1-dim array of minimums (shape (k,)) into a column (shape (k, 1)) so that it broadcasts to the 2-dim target of the assignment (shape (k, 2)).

import numpy as np

a = np.array([[0,1,2],[4,3,6],[9,5,7],[8,9,8]])
row_mask = (a[:, 1:] >= 3).all(axis=1)  # True where both the 2nd and 3rd elements are >= 3
a[row_mask, 1:] = a[row_mask, 1:].min(axis=1)[..., np.newaxis]
a
=> 
array([[0, 1, 2],
       [4, 3, 3],
       [9, 5, 5],
       [8, 8, 8]])
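To make the broadcasting step concrete, here is a small sketch comparing the shapes involved. `[..., np.newaxis]` and `keepdims=True` (used in the other answer) produce the same (k, 1) column. Without either, the assignment fails:

```python
import numpy as np

a = np.array([[0, 1, 2], [4, 3, 6], [9, 5, 7], [8, 9, 8]])
row_mask = (a[:, 1:] >= 3).all(axis=1)   # shape (4,), one flag per row
target = a[row_mask, 1:]                 # shape (3, 2), a copy of the selected rows

mins_1d = target.min(axis=1)                    # shape (3,)
mins_col = mins_1d[..., np.newaxis]             # shape (3, 1)
mins_keep = target.min(axis=1, keepdims=True)   # shape (3, 1), same values
print(mins_1d.shape, mins_col.shape, mins_keep.shape)

try:
    # without the new axis, (3,) is matched against the trailing axis of length 2
    a[row_mask, 1:] = mins_1d
except ValueError as err:
    print("ValueError:", err)
```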
answered by shx2, Oct 23 '25 11:10



