 

For each row, what is the fastest way to find the column holding nth element that is not NaN?

I have a Python pandas DataFrame in which each element is a float or NaN. For each row, I need to find the column that holds the row's nth non-NaN value. I know that the nth such column always exists.

So if n were 4 and a pandas DataFrame called myDF were the following:

      10   20   30   40   50   60   70   80   90  100
'A'  4.5  5.5  2.5  NaN  NaN  2.9  NaN  NaN  1.1  1.8
'B'  4.7  4.1  NaN  NaN  NaN  2.0  1.2  NaN  NaN  NaN
'C'  NaN  NaN  NaN  NaN  NaN  1.9  9.2  NaN  4.4  2.1
'D'  1.1  2.2  3.5  3.4  4.5  NaN  NaN  NaN  1.9  5.5

I would want to obtain:

'A'  60
'B'  70
'C'  100 
'D'  40

I could do:

import pandas as pd
import math

n = 4  # some arbitrary int
for row in myDF.index:
    num_not_nan = 0
    for c in myDF.columns:
        # count the non-NaN entries seen so far in this row
        if not math.isnan(myDF.at[row, c]):
            num_not_nan += 1
        if num_not_nan == n:
            print(row, c)
            break

I'm sure this is very slow and not very Pythonic. Is there an approach that will be faster if I am dealing with a very large DataFrame and large values of n?

asked Aug 12 '15 by A. Arpi


2 Answers

If speed is your goal, it's a good idea to make use of Pandas' vectorised methods whenever you can:

>>> (df.notnull().cumsum(axis=1) == 4).idxmax(axis=1) # replace 4 with any number you like
'A'     60
'B'     70
'C'    100
'D'     40
dtype: object
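
Breaking this down: df.notnull() gives a boolean mask, .cumsum(axis=1) turns each row into a running count of its non-NaN values, the comparison marks the column where that count first reaches n, and .idxmax(axis=1) returns the label of the first True in each row. A sketch of the intermediate running count for the example df above:

>>> df.notnull().cumsum(axis=1)
      10  20  30  40  50  60  70  80  90  100
'A'    1   2   3   3   3   4   4   4   5    6
'B'    1   2   3   3   3   4   5   5   5    5

Hold on, row 'B' should be recomputed from the example data:

'B'    1   2   2   2   2   3   4   4   4    4
'C'    0   0   0   0   0   1   2   2   3    4
'D'    1   2   3   4   5   5   5   5   6    7

Reading across each row, the first column where the count hits 4 is exactly the result above.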

The other answers are good and are perhaps a little clearer syntactically. In terms of speed, there's not much difference between them for your small example. However, for a slightly larger DataFrame, the vectorised method is already around 60 times faster:

>>> df2 = pd.concat([df]*1000) # 4000 row DataFrame
>>> %timeit df2.apply(lambda row: get_nth(row, n), axis=1)
1 loops, best of 3: 749 ms per loop

>>> %timeit df2.T.apply(lambda x: x.dropna()[n-1:].index[0])
1 loops, best of 3: 673 ms per loop

>>> %timeit (df2.notnull().cumsum(1) == 4).idxmax(axis=1)
100 loops, best of 3: 10.5 ms per loop
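
One caveat, harmless here since the question guarantees the nth non-NaN value always exists: if some row had fewer than n non-NaN values, the comparison would be False everywhere and idxmax(axis=1) would silently return that row's first column label. A rough sketch of a guard, should you ever need one, masks those rows back to NaN:

>>> hits = df.notnull().cumsum(axis=1) == n
>>> hits.idxmax(axis=1).where(hits.any(axis=1))  # NaN for rows that never reach n non-NaN values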
answered Sep 23 '22 by Alex Riley


You could create a helper function and apply it to each row via a lambda.

The function filters the series for nulls, then returns the index label of the nth item (or None if the series holds fewer than n non-null values).

The apply call needs axis=1 to ensure the function is applied to each row of the DataFrame.

def get_nth(series, n):
    s = series[series.notnull()]
    if len(s) >= n:
        return s.index[n - 1]

>>> n = 4
>>> df.apply(lambda row: get_nth(row, n), axis=1)
A     60
B     70
C    100
D     40
dtype: object
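
For reference, essentially the same logic fits in a one-liner (a sketch that assumes, as the question states, that the nth non-NaN value always exists, so no length check is needed):

>>> df.apply(lambda row: row.dropna().index[n - 1], axis=1)

The named helper above is preferable when some rows might have fewer than n values, since it returns None instead of raising an IndexError.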
answered Sep 25 '22 by Alexander