
Apply a lambda with a shift function in Python pandas where some null elements are to be replaced

I am trying to do the following in a DataFrame: where Period is not 1, set the value of the Attrition column to the Retention value in that row multiplied by the Attrition value in the row above, within each group of the groupby. My attempt is below:

import pandas as pd

data = {'Country': ['DE', 'DE', 'DE', 'US', 'US', 'US', 'FR', 'FR', 'FR'],
    'Week': ['201426', '201426', '201426', '201426', '201425', '201425', '201426', '201426', '201426'],
    'Period': [1, 2, 3, 1, 1, 2, 1, 2, 3],
    'Attrition': [0.5,'' ,'' ,0.85 ,0.865,'' ,0.74 ,'','' ],
    'Retention': [0.95,0.85,0.94,0.85,0.97,0.93,0.97,0.93,0.94]}

df = pd.DataFrame(data, columns= ['Country', 'Week', 'Period', 'Attrition','Retention'])
print(df)

Gives me this output:

  Country    Week  Period Attrition  Retention
0      DE  201426       1       0.5       0.95
1      DE  201426       2                 0.85
2      DE  201426       3                 0.94
3      US  201426       1      0.85       0.85
4      US  201425       1     0.865       0.97
5      US  201425       2                 0.93
6      FR  201426       1      0.74       0.97
7      FR  201426       2                 0.93
8      FR  201426       3                 0.94

The line below:

df['Attrition'] = df.groupby(['Country','Week']).apply(lambda x: x.Attrition.shift(1)*x['Retention'] if x.Period != 1 else x.Attrition)

print(df)

gives me this error:

df['Attrition'] = df.groupby(['Country','Week']).apply(lambda x: x.Attrition.shift(1)*x['Retention'] if x.Period != 1 else x.Attrition)

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

UPDATE: Complete Compiled Solution

Below is the complete working solution I was after. It is basically Primer's answer, with a while loop added that keeps running the lambda function on the DataFrame column until there are no more NaNs.

import pandas as pd
import numpy as np

data = {'Country': ['DE', 'DE', 'DE', 'US', 'US', 'US', 'FR', 'FR', 'FR'],
    'Week': ['201426', '201426', '201426', '201426', '201425', '201425', '201426', '201426', '201426'],
    'Period': [1, 2, 3, 1, 1, 2, 1, 2, 3],
    'Attrition': [0.5, '' ,'' ,0.85 ,0.865,'' ,0.74 ,'','' ],
    'Retention': [0.95,0.85,0.94,0.85,0.97,0.93,0.97,0.93,0.94]}

df = pd.DataFrame(data, columns= ['Country', 'Week', 'Period', 'Attrition','Retention'])
print(df)

OUTPUT: Starting DF

  Country    Week  Period Attrition  Retention
0      DE  201426       1       0.5       0.95
1      DE  201426       2                 0.85
2      DE  201426       3                 0.94
3      US  201426       1      0.85       0.85
4      US  201425       1     0.865       0.97
5      US  201425       2                 0.93
6      FR  201426       1      0.74       0.97
7      FR  201426       2                 0.93
8      FR  201426       3                 0.94

Solution:

# Replace empty strings with NaN and make the column numeric
df['Attrition'] = df['Attrition'].replace('', np.nan).astype(float)

# Count the NaNs in the column
ContainsNaN = df['Attrition'].isnull().sum()

# Run the loop while there are still NaNs in the column.
# group_keys=False and index=x.index keep the original row labels,
# so the assignment lines up row by row.
while ContainsNaN > 0:
    df['Attrition'] = df.groupby(['Country','Week'], group_keys=False).apply(lambda x: pd.Series(np.where(x.Period != 1, x.Attrition.shift() * x['Retention'], x.Attrition), index=x.index))
    ContainsNaN = df['Attrition'].isnull().sum()

print(df)

OUTPUT: Result

  Country    Week  Period Attrition  Retention
0      DE  201426       1       0.5       0.95
1      DE  201426       2     0.425       0.85
2      DE  201426       3    0.3995       0.94
3      US  201426       1      0.85       0.85
4      US  201425       1     0.865       0.97
5      US  201425       2   0.80445       0.93
6      FR  201426       1      0.74       0.97
7      FR  201426       2    0.6882       0.93
8      FR  201426       3  0.646908       0.94
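
For comparison, the same compounding can be sketched without a loop. This is only an illustration, not part of the accepted answer: it assumes the rows of each (Country, Week) group are already ordered by Period and that each group starts with a Period 1 row holding a non-null Attrition. The names factor, running and start are made up for the example.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'Country': ['DE', 'DE', 'DE', 'US', 'US', 'US', 'FR', 'FR', 'FR'],
    'Week': ['201426', '201426', '201426', '201426', '201425', '201425',
             '201426', '201426', '201426'],
    'Period': [1, 2, 3, 1, 1, 2, 1, 2, 3],
    'Attrition': [0.5, np.nan, np.nan, 0.85, 0.865, np.nan, 0.74, np.nan, np.nan],
    'Retention': [0.95, 0.85, 0.94, 0.85, 0.97, 0.93, 0.97, 0.93, 0.94],
})

keys = [df['Country'], df['Week']]

# Period 1 contributes a factor of 1; later periods contribute their Retention.
factor = df['Retention'].where(df['Period'] != 1, 1.0)

# Running product of those factors within each (Country, Week) group.
running = factor.groupby(keys).cumprod()

# First non-null Attrition of each group, repeated down the group.
start = df.groupby(['Country', 'Week'])['Attrition'].transform('first')

df['Attrition'] = start * running
print(df)
```

Because every row is computed from the group's starting value and a cumulative product, no NaN has to be propagated step by step, so there is nothing to iterate over.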
IcemanBerlin asked Oct 13 '25 02:10



1 Answer

First of all, your Attrition column mixes numeric data with empty strings (''), which is generally not a good idea and should be fixed before attempting calculations on this column:

import numpy as np

# pd.np is gone from current pandas, so use numpy directly
df.loc[df['Attrition'] == '', 'Attrition'] = np.nan
df['Attrition'] = df['Attrition'].astype('float')
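
As an alternative sketch, pd.to_numeric with errors='coerce' does the cleanup in one step. The toy Series below is made up for illustration and only mimics the question's column.

```python
import pandas as pd

# Toy column mixing floats and empty strings, as in the question
attrition = pd.Series([0.5, '', '', 0.85])

# errors='coerce' turns anything that cannot be parsed as a number into NaN
cleaned = pd.to_numeric(attrition, errors='coerce')

print(cleaned.dtype)
print(cleaned.isna().sum())
```

The trade-off is that coercion silently turns any other unparseable value into NaN as well, so it is worth checking the NaN count afterwards.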

The error you get arises because the condition in your .apply, x.Period != 1, produces a Boolean array rather than a single True or False. For the whole DataFrame it looks like this:

0    False
1     True
2     True
3    False
4    False
5     True
6    False
7     True
8     True
Name: Period, dtype: bool

A Python if ... else expression needs a single True or False, so it cannot decide what to do with a whole array (i.e. which element should count as True in this case?).
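
To see the ambiguity in isolation, here is a minimal sketch. The Series period is a made-up stand-in for x.Period: a plain if on a multi-element Boolean Series raises ValueError, while numpy.where chooses a value per element.

```python
import numpy as np
import pandas as pd

period = pd.Series([1, 2, 3])   # hypothetical stand-in for x.Period
mask = period != 1              # Boolean Series: False, True, True

# A plain `if` needs one truth value, so a multi-element Series raises.
try:
    if mask:
        pass
    raised = False
except ValueError:
    raised = True

# np.where evaluates the condition element by element instead.
picked = np.where(mask, 'compute', 'keep')

print(raised)
print(picked.tolist())
```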

You might consider numpy.where for this task:

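import numpy as np
g = df.groupby(['Country', 'Week'], group_keys=False)
# Build each group's result on the group's own index so the assignment aligns row by row
df['Attrition'] = g.apply(lambda x: pd.Series(np.where(x.Period != 1, x.Attrition.shift() * x['Retention'], x.Attrition), index=x.index).ffill())
df

yielding (values shown to three decimals):

  Country    Week  Period  Attrition  Retention
0      DE  201426       1      0.500       0.95
1      DE  201426       2      0.425       0.85
2      DE  201426       3      0.425       0.94
3      US  201426       1      0.850       0.85
4      US  201425       1      0.865       0.97
5      US  201425       2      0.804       0.93
6      FR  201426       1      0.740       0.97
7      FR  201426       2      0.688       0.93
8      FR  201426       3      0.688       0.94

Note that I have added .ffill(), which fills each remaining NaN with the last observed value in its group. That is why the Period 3 rows repeat the Period 2 value. The result is built on each group's own index rather than assigned through .values, because groupby returns the groups in sorted key order, which differs from the row order here.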

Primer answered Oct 14 '25 16:10
