numpy.where() with 3 or more conditions

I have a dataframe with multiple columns.

      AC     BC     CC      DC     MyColumn
A
B
C
D

I would like to create a new column "MyColumn" as follows: if BC, CC, and DC are all less than AC, take the max of the three for that row. If only CC and DC are less than AC, take the max of CC and DC for that row, and so on for the other combinations. If none of them are less than AC, MyColumn should just take the value from AC.

How would I do this with numpy.where()?

mit13plee asked Mar 31 '14 17:03


People also ask

Can np.where() have multiple conditions?

Yes. numpy.where(condition, x, y) returns an array built element by element from x where the condition is True and from y where it is False, and the condition can combine several comparisons. Called with only a condition, np.where() instead returns the indices of the elements for which the condition is True.
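A minimal sketch contrasting the two call forms on a small made-up array (the array values are purely illustrative):

```python
import numpy as np

a = np.array([3, -1, 0, 7, -5])

# One-argument form: a tuple of index arrays, one per dimension
idx = np.where(a > 0)
print(idx[0].tolist())  # [0, 3]

# Three-argument form: take from a where the condition holds, otherwise 0
clipped = np.where(a > 0, a, 0)
print(clipped.tolist())  # [3, 0, 0, 7, 0]
```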

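Can you use "and" in np.where()?

Not Python's "and" keyword, which cannot reduce a whole array to a single True/False. Instead, combine conditions with & (element-wise and) or | (element-wise or), and wrap each comparison in parentheses, because & and | bind more tightly than < or >.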
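For illustration, a small sketch with a made-up array showing & and | inside np.where(), each comparison parenthesized:

```python
import numpy as np

a = np.array([1, 5, 8, 12, 20])

# Each comparison sits in its own parentheses: & and | bind more tightly than < and >
in_range = np.where((a > 4) & (a < 15), "mid", "out")
outside = np.where((a < 4) | (a > 15), a, -1)

print(in_range.tolist())  # ['out', 'mid', 'mid', 'mid', 'out']
print(outside.tolist())   # [1, -1, -1, -1, 20]

# By contrast, (a > 4) and (a < 15) raises ValueError, because "and" asks for
# the truth value of a whole array, which NumPy treats as ambiguous
```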

What does .all() do in NumPy?

numpy.all() tests whether all array elements along the given axis evaluate to True; with no axis it tests the whole array and returns a single boolean.
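A short sketch of the axis argument on a made-up 2x3 array. The .all(1) calls in the answer below do the same thing on a pandas DataFrame (DataFrame.all with axis=1, one result per row):

```python
import numpy as np

m = np.array([[1, 2, 3],
              [4, -1, 6]])
positive = m > 0

print(np.all(positive))                   # False: one element fails
print(np.all(positive, axis=1).tolist())  # [True, False]: one result per row
print(np.all(positive, axis=0).tolist())  # [True, False, True]: one result per column
```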

Can I nest np.where()?

Yes. np.where() calls can be nested, which works like a CASE WHEN ... THEN ... ELSE chain in SQL: the third argument of one call is another np.where().
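A minimal sketch of a nested chain, using made-up scores and grade cut-offs purely for illustration. Conditions are checked outermost first, so the first one that matches wins:

```python
import numpy as np

scores = np.array([95, 72, 58, 84])

# Roughly: CASE WHEN s >= 90 THEN 'A' WHEN s >= 80 THEN 'B'
#               WHEN s >= 70 THEN 'C' ELSE 'F' END
grades = np.where(scores >= 90, "A",
         np.where(scores >= 80, "B",
         np.where(scores >= 70, "C", "F")))
print(grades.tolist())  # ['A', 'C', 'F', 'B']
```

For long chains, numpy.select(condlist, choicelist, default) expresses the same first-match logic without the nesting.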


1 Answer

You can use the DataFrame lt method along with pandas' where (Series.where, rather than numpy.where):

In [11]: df = pd.DataFrame(np.random.randn(5, 4), columns=list('ABCD'))

In [12]: df
Out[12]:
          A         B         C         D
0  1.587878 -2.189620  0.631958 -0.432253
1 -1.636721  0.568846 -0.033618 -0.648406
2  1.567512  1.089788  0.489559  1.673372
3  0.589222 -1.176961 -1.186171  0.249795
4  0.366227  1.830107 -1.074298 -1.882093

Note: you can take max of a subset of columns:

In [13]: df[['B', 'C', 'D']].max(1)
Out[13]:
0    0.631958
1    0.568846
2    1.673372
3    0.249795
4    1.830107
dtype: float64

Look at each column's values to see if they are less than A:

In [14]: lt_A = df.lt(df['A'], axis=0)

In [15]: lt_A
Out[15]:
       A      B      C      D
0  False   True   True   True
1  False  False  False  False
2  False   True   True  False
3  False   True   True   True
4  False  False   True   True

In [15]: lt_A[['B', 'C', 'D']].all(1)
Out[15]:
0     True
1    False
2    False
3     True
4    False
dtype: bool
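In df.lt(df['A'], axis=0), axis=0 aligns the Series with the rows, so every value is compared with the A value of its own row. A plain-NumPy sketch of the same idea (small made-up array) broadcasts against a column vector:

```python
import numpy as np

arr = np.array([[1.0, 0.5, 2.0],
                [0.0, 1.0, -1.0]])

# arr[:, [0]] keeps shape (2, 1), so each row is compared with its own first value
lt_first = arr < arr[:, [0]]
print(lt_first.tolist())  # [[False, True, False], [False, False, True]]
```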

Now, you can build up your desired result using all:

In [16]: df[['B', 'C', 'D']].max(1).where(lt_A[['B', 'C', 'D']].all(1), 2)
Out[16]:
0    0.631958
1    2.000000
2    2.000000
3    0.249795
4    2.000000
dtype: float64

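Rather than the placeholder 2, you can substitute the result for C and D alone. Here it is on its own, still with 2 as its fallback (for rows 0 and 3 it happens to equal the B, C, D maximum):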

In [17]: df[['C', 'D']].max(1).where(lt_A[['C', 'D']].all(1), 2)
Out[17]:
0    0.631958
1    2.000000
2    2.000000
3    0.249795
4   -1.074298
dtype: float64

and then fall back to column A:

In [18]: df[['B', 'C', 'D']].max(1).where(lt_A[['B', 'C', 'D']].all(1), df[['C', 'D']].max(1).where(lt_A[['C', 'D']].all(1), df['A']))
Out[18]:
0    0.631958
1   -1.636721
2    1.567512
3    0.249795
4   -1.074298
dtype: float64
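The one-liner in In [18] can be unpacked into named steps. This sketch rebuilds the DataFrame from the values printed in Out[12] (so the random data is reproducible) and builds the chain from the innermost fallback outward; the intermediate names are just for readability:

```python
import pandas as pd

# The five rows shown in Out[12], copied so the result is reproducible
df = pd.DataFrame(
    {
        "A": [1.587878, -1.636721, 1.567512, 0.589222, 0.366227],
        "B": [-2.189620, 0.568846, 1.089788, -1.176961, 1.830107],
        "C": [0.631958, -0.033618, 0.489559, -1.186171, -1.074298],
        "D": [-0.432253, -0.648406, 1.673372, 0.249795, -1.882093],
    }
)

lt_A = df.lt(df["A"], axis=0)  # True where a value is less than A in its row

all_bcd = lt_A[["B", "C", "D"]].all(axis=1)  # B, C and D all below A
all_cd = lt_A[["C", "D"]].all(axis=1)        # C and D both below A

# Inner level: max(C, D) if both are below A, else A
fallback = df[["C", "D"]].max(axis=1).where(all_cd, df["A"])
# Outer level: max(B, C, D) if all three are below A, else the inner result
result = df[["B", "C", "D"]].max(axis=1).where(all_bcd, fallback)
print(result)  # same values as Out[18]
```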

Clearly, you should write this as a function if you're planning on reusing it!
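One possible reusable version, sketched with numpy.where as the question asked and the AC/BC/CC/DC column names from the question. The helper names my_column and my_column_any_subset and the sample data are hypothetical. The first function reproduces the two-level rule above; the second shows one assumed reading of "and so on" (take the max of whichever columns are below AC), which differs when, say, only BC is below AC:

```python
import numpy as np
import pandas as pd


def my_column(df: pd.DataFrame, base: str = "AC") -> np.ndarray:
    """Hypothetical helper: the answer's two-level rule written with np.where.

    BC, CC and DC all below base -> max(BC, CC, DC)
    else CC and DC below base    -> max(CC, DC)
    else                         -> value of base
    """
    lt = df[["BC", "CC", "DC"]].lt(df[base], axis=0)
    return np.where(
        lt.all(axis=1),
        df[["BC", "CC", "DC"]].max(axis=1),
        np.where(lt[["CC", "DC"]].all(axis=1), df[["CC", "DC"]].max(axis=1), df[base]),
    )


def my_column_any_subset(df: pd.DataFrame, base: str = "AC") -> pd.Series:
    """Hypothetical alternative: max of whichever of BC, CC, DC are below base,
    or base itself when none are."""
    others = df[["BC", "CC", "DC"]]
    below = others.where(others.lt(df[base], axis=0))  # NaN where not below base
    return below.max(axis=1).fillna(df[base])  # all-NaN rows fall back to base


# Made-up sample rows, chosen to hit each branch
df = pd.DataFrame({"AC": [5, 5, 5, 5], "BC": [1, 9, 9, 1],
                   "CC": [2, 3, 9, 9], "DC": [4, 1, 9, 9]})
df["MyColumn"] = my_column(df)
df["MyColumnAny"] = my_column_any_subset(df)
print(df)
```

In the last sample row only BC is below AC, so the two-level rule returns AC (5) while the any-subset reading returns BC (1).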

Andy Hayden answered Oct 19 '22 23:10