 

Numpy recursive compare against all data in row

Tags:

python

numpy

I have a boolean NumPy array as follows:

import numpy as np

bool_arr = np.array([[ True,  True,  True,  True],
                     [False, False,  True,  True],
                     [False, False, False,  True]], dtype=bool)

I want to work along each row, returning True only for the first True in that row and False everywhere else. For bool_arr, the result I'm after would be:

ans = np.array([[ True, False, False, False],
                [False, False,  True, False],
                [False, False, False,  True]], dtype=bool)

I can't figure out the idiomatic NumPy way to do this without loops.
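To pin down the desired behavior, here is a minimal loop-based sketch of the transformation, i.e. the kind of code the question wants to replace. The function name `first_true_loop` is hypothetical, not a NumPy API; it simply scans each row and stops at the first True.

```python
import numpy as np

def first_true_loop(bool_arr):
    """Keep only the first True in each row, using explicit Python loops."""
    res = np.zeros_like(bool_arr)  # all-False array with the same shape and dtype
    for i, row in enumerate(bool_arr):
        for j, value in enumerate(row):
            if value:
                res[i, j] = True
                break  # stop at the first True in this row
    return res

bool_arr = np.array([[True, True, True, True],
                     [False, False, True, True],
                     [False, False, False, True]], dtype=bool)
print(first_true_loop(bool_arr))
```

A row with no True values is left all-False, which is a useful reference when checking a vectorized version.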

asked May 25 '26 01:05 by johntfoster


1 Answer

You can use argmax along axis=1 to find the index of the first True value in each row, then construct a new all-False array and set True at those indices:

import numpy as np

res = np.zeros_like(bool_arr)  # all-False array with the same shape
res[np.arange(res.shape[0]), np.argmax(bool_arr, axis=1)] = True

This relies on the fact that argmax returns the index of the first occurrence of the maximum value; in a boolean row the maximum is True whenever the row contains at least one True.
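A quick single-row check of that first-occurrence behavior, including a row with no True values at all, where the maximum is False and argmax has nothing useful to report:

```python
import numpy as np

# A row with several True values: argmax reports the first one
row = np.array([False, True, True, False])
print(np.argmax(row))  # 1, even though index 2 is also True

# A row with no True values: every element ties at False, so argmax returns 0
empty_row = np.array([False, False, False])
print(np.argmax(empty_row))  # 0, which would wrongly mark column 0
```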


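EDIT: @unutbu pointed out that this doesn't work for rows that contain no True values: for an all-False row, argmax returns 0, so column 0 would incorrectly be set to True. To fix this, we can mask out the indices of those rows: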

import numpy as np

m = bool_arr.any(axis=1)  # True for rows that contain at least one True
res = np.zeros_like(bool_arr)
res[np.arange(res.shape[0])[m], np.argmax(bool_arr, axis=1)[m]] = True
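Wrapping the masked version in a function makes it easy to test on an array that includes an all-False row. The helper names `first_true` and `first_true_cumsum` are hypothetical. The second one is a different technique, swapped in here only as a cross-check and not part of the answer above: a True is the first in its row exactly when the running count of Trues (cumsum along axis=1) equals 1 at that position.

```python
import numpy as np

def first_true(bool_arr):
    """Masked argmax approach: keep only the first True in each row."""
    m = bool_arr.any(axis=1)                # rows containing at least one True
    res = np.zeros_like(bool_arr)
    rows = np.arange(bool_arr.shape[0])[m]  # indices of those rows
    cols = np.argmax(bool_arr, axis=1)[m]   # first True column in each such row
    res[rows, cols] = True
    return res

def first_true_cumsum(bool_arr):
    """Alternative: a True is 'first' where the running count of Trues is 1."""
    return bool_arr & (np.cumsum(bool_arr, axis=1) == 1)

arr = np.array([[True, True, True, True],
                [False, False, True, True],
                [False, False, False, False]], dtype=bool)  # last row has no True
print(first_true(arr))
print(np.array_equal(first_true(arr), first_true_cumsum(arr)))
```

The argmax version touches only one element per row, while the cumsum version builds a full integer array of running counts; for small arrays either is fine.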
answered May 27 '26 15:05 by shx2


