I have a Python pandas DataFrame in which each element is a float or NaN. For each row, I will need to find the column which holds the nth number of the row. That is, I need to get the column holding the nth element of the row that is not NaN. I know that the nth such column always exists.
For example, if n were 4 and a pandas DataFrame called myDF were the following:
      10   20   30   40   50   60   70   80   90   100
'A'   4.5  5.5  2.5  NaN  NaN  2.9  NaN  NaN  1.1  1.8
'B'   4.7  4.1  NaN  NaN  NaN  2.0  1.2  NaN  NaN  NaN
'C'   NaN  NaN  NaN  NaN  NaN  1.9  9.2  NaN  4.4  2.1
'D'   1.1  2.2  3.5  3.4  4.5  NaN  NaN  NaN  1.9  5.5
I would want to obtain:
'A' 60
'B' 70
'C' 100
'D' 40
I could do:
import math
import pandas as pd

n = 4  # any positive integer
for row in myDF.index:
    num_not_NaN = 0  # non-NaN values seen so far in this row
    for c in myDF.columns:
        if not math.isnan(myDF.loc[row, c]):
            num_not_NaN += 1
            if num_not_NaN == n:
                print(row, c)
                break
I'm sure this is very slow and not very Pythonic. Is there an approach that will be faster if I am dealing with a very large DataFrame and large values of n?
If speed is your goal, it's a good idea to make use of Pandas' vectorised methods whenever you can:
>>> (df.notnull().cumsum(axis=1) == 4).idxmax(axis=1) # replace 4 with any number you like
'A' 60
'B' 70
'C' 100
'D' 40
dtype: object
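To see why the one-liner works, it may help to split it into its intermediate steps. The sketch below rebuilds the DataFrame from the question (integer column labels are an assumption, and the names present, running and hits are only illustrative) and applies each step in turn.

```python
import numpy as np
import pandas as pd

nan = np.nan
# The example frame from the question; integer column labels are assumed.
df = pd.DataFrame(
    [
        [4.5, 5.5, 2.5, nan, nan, 2.9, nan, nan, 1.1, 1.8],
        [4.7, 4.1, nan, nan, nan, 2.0, 1.2, nan, nan, nan],
        [nan, nan, nan, nan, nan, 1.9, 9.2, nan, 4.4, 2.1],
        [1.1, 2.2, 3.5, 3.4, 4.5, nan, nan, nan, 1.9, 5.5],
    ],
    index=list("ABCD"),
    columns=range(10, 101, 10),
)
n = 4

present = df.notnull()            # True wherever a value exists
running = present.cumsum(axis=1)  # running count of non-NaN values along each row
hits = running == n               # True from the nth value until the (n+1)th appears
result = hits.idxmax(axis=1)      # column label of the first True in each row
print(result)
```

For n = 4 this gives 60, 70, 100 and 40 for rows A to D, matching the output above. Note that idxmax returns the first column when a row contains no True at all, so the approach relies on the nth value always existing, as stated in the question.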
The other answers are good and are perhaps a little clearer syntactically. In terms of speed, there's not much difference between them for your small example. However, for a slightly larger DataFrame, the vectorised method is already around 60 times faster:
>>> df2 = pd.concat([df]*1000) # 4000 row DataFrame
>>> %timeit df2.apply(lambda row: get_nth(row, n), axis=1)
1 loops, best of 3: 749 ms per loop
>>> %timeit df2.T.apply(lambda x: x.dropna()[n-1:].index[0])
1 loops, best of 3: 673 ms per loop
>>> %timeit (df2.notnull().cumsum(1) == 4).idxmax(axis=1)
100 loops, best of 3: 10.5 ms per loop
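The %timeit lines above only run inside IPython. As a rough sketch, a similar comparison can be made in a plain script with the standard timeit module. The frame is the one from the question with assumed integer column labels, the variable names are illustrative, and the absolute numbers will differ from machine to machine.

```python
import timeit

import numpy as np
import pandas as pd

nan = np.nan
df = pd.DataFrame(
    [
        [4.5, 5.5, 2.5, nan, nan, 2.9, nan, nan, 1.1, 1.8],
        [4.7, 4.1, nan, nan, nan, 2.0, 1.2, nan, nan, nan],
        [nan, nan, nan, nan, nan, 1.9, 9.2, nan, 4.4, 2.1],
        [1.1, 2.2, 3.5, 3.4, 4.5, nan, nan, nan, 1.9, 5.5],
    ],
    index=list("ABCD"),
    columns=range(10, 101, 10),
)
df2 = pd.concat([df] * 1000)  # 4000-row DataFrame
n = 4


def get_nth(series, n):
    """Return the label of the nth non-null item, or None if there is none."""
    s = series[series.notnull()]
    if len(s) >= n:
        return s.index[n - 1]


def by_apply():
    return df2.apply(lambda row: get_nth(row, n), axis=1)


def by_cumsum():
    return (df2.notnull().cumsum(axis=1) == n).idxmax(axis=1)


# Check that both approaches agree before comparing their speed.
same = bool((by_apply().to_numpy() == by_cumsum().to_numpy()).all())

t_apply = timeit.timeit(by_apply, number=3) / 3
t_cumsum = timeit.timeit(by_cumsum, number=3) / 3
print("results agree:", same)
print("apply:  %.1f ms per call" % (t_apply * 1000))
print("cumsum: %.1f ms per call" % (t_cumsum * 1000))
```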
You could write a function and apply it to each row through a lambda function.
The function filters the nulls out of the series and then returns the index label of the nth item (or None if the series has fewer than n non-null items).
The apply call needs axis=1 so that the lambda function is applied to each row of the DataFrame.
def get_nth(series, n):
    s = series[series.notnull()]
    if len(s) >= n:
        return s.index[n - 1]
>>> n = 4
>>> df.apply(lambda row: get_nth(row, n), axis=1)
A 60
B 70
C 100
D 40
dtype: object
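get_nth returns None when a row has fewer than n non-null values. If that safeguard is also wanted in a vectorised form, one possible sketch is to blank out the rows that never reach a count of n. The helper name nth_valid_column and the small demo frame are hypothetical and used only for illustration.

```python
import numpy as np
import pandas as pd


def nth_valid_column(frame, n):
    """Label of the nth non-null value per row, NaN where a row has fewer than n."""
    hits = frame.notnull().cumsum(axis=1) == n
    # idxmax falls back to the first column for an all-False row,
    # so keep its answer only for rows that actually contain a True.
    return hits.idxmax(axis=1).where(hits.any(axis=1))


# Hypothetical demo data: row "x" has two values, row "y" only one.
demo = pd.DataFrame(
    {"a": [1.0, np.nan], "b": [np.nan, 2.0], "c": [3.0, np.nan]},
    index=["x", "y"],
)
result = nth_valid_column(demo, 2)
print(result)
```

Here row x yields c, while row y has only one value and so yields NaN instead of a misleading column label.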