
Pandas, Avoiding hierarchy in pivot table

Tags: python, pandas

I have a pandas DataFrame, df, from which I generate a pivot table with the following function:

import pandas as pd

def objective2(excel_file):
    df = pd.read_excel(excel_file)

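    # WBC cut-offs: bins (0, 4], (4, 12] and (12, 100]
    df['WBC_groups'] = pd.cut(df.WBC, [0, 4, 12, 100],
                              labels=['WBC < 4', 'WBC Normal', 'WBC > 12'])

    # helper column: summing it in the pivot counts the rows per cell
    df['count'] = 1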

    table = df.pivot_table('count', index=['Sex'],
                           columns=['WBC_groups', 'Outcome_at_24'],
                           aggfunc='sum',
                           margins=True, margins_name='Total')

    return table

This generates the following table:

WBC_groups     WBC < 4       WBC Normal    WBC > 12      Total
Outcome_at_24  Alive   Died  Alive   Died  Alive   Died
Sex
Female          10.0    2.0   20.0    6.0   14.0    NaN   86.0
Male             3.0    NaN   28.0    3.0   26.0    4.0  111.0
Total           13.0    2.0   48.0    9.0   40.0    4.0  197.0

How can I avoid the hierarchy in the columns so that the table looks like this:

WBC_groups  WBC < 4  WBC Normal  WBC > 12  Alive  Died  Total
Sex
Female         10.0         2.0      20.0    6.0  14.0   86.0
Male            3.0         NaN      28.0    3.0  26.0  111.0
Total          13.0         2.0      48.0    9.0  40.0  197.0

Note: the data in these tables are not accurate; they are just dummy values.
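To try the pivot without an Excel file, here is a minimal sketch that runs the same steps as `objective2` on a small in-memory DataFrame instead of `pd.read_excel`. The rows are hypothetical sample values (not the data behind the tables above), chosen so that every WBC group/outcome pair occurs at least once; the prints at the end show the two-level column structure and the margin column.

```python
import pandas as pd

# Hypothetical sample rows standing in for the Excel sheet; every
# (WBC group, outcome) combination appears at least once.
df = pd.DataFrame({
    "Sex": ["Female", "Female", "Female", "Male", "Male", "Male", "Male"],
    "WBC": [3.0, 2.5, 8.0, 9.0, 15.0, 20.0, 6.0],
    "Outcome_at_24": ["Alive", "Died", "Alive", "Died", "Alive", "Died", "Alive"],
})

# Same binning and helper column as objective2
df["WBC_groups"] = pd.cut(df.WBC, [0, 4, 12, 100],
                          labels=["WBC < 4", "WBC Normal", "WBC > 12"])
df["count"] = 1

table = df.pivot_table("count", index=["Sex"],
                       columns=["WBC_groups", "Outcome_at_24"],
                       aggfunc="sum",
                       margins=True, margins_name="Total")

print(table.columns.nlevels)  # two levels: WBC_groups and Outcome_at_24
print(table.columns[-1])      # margin column; its second-level label is ''
```

The empty second-level label on the margin column is what makes the `Total` column sit awkwardly in a two-level header.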

asked Jan 22 '26 06:01 by Amani


1 Answer

I don't think you can avoid the hierarchy, because you pass two columns, WBC_groups and Outcome_at_24, to the columns parameter of pivot_table, so the resulting column index has two levels (a MultiIndex).
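A toy example makes the point concrete: with one key in `columns` the result has a flat column Index, and with two keys it has a two-level MultiIndex. The frame and its column names (`row`, `a`, `b`, `n`) are made up for this illustration.

```python
import pandas as pd

# Hypothetical toy data: a row key, two column keys and a value to count.
toy = pd.DataFrame({
    "row": ["x", "x", "y", "y"],
    "a": ["a1", "a2", "a1", "a2"],
    "b": ["b1", "b1", "b2", "b2"],
    "n": [1, 1, 1, 1],
})

one_key = toy.pivot_table("n", index="row", columns=["a"], aggfunc="sum")
two_keys = toy.pivot_table("n", index="row", columns=["a", "b"], aggfunc="sum")

print(one_key.columns.nlevels)   # 1: plain Index
print(two_keys.columns.nlevels)  # 2: MultiIndex of (a, b) pairs
print(list(two_keys.columns))
```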

The easiest solution is to assign new column names and then drop the placeholder column rem (here df is the pivot table returned by objective2):

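# one flat name per existing column, assigned by position;
# 'rem' labels the sixth column, ('WBC > 12', 'Died'), which is then dropped
df.columns = ['WBC < 4', 'WBC Normal', 'WBC > 12', 'Alive', 'Died', 'rem', 'Total']
df = df.drop('rem', axis=1)
print(df)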
        WBC < 4  WBC Normal  WBC > 12  Alive  Died  Total
Sex                                                      
Female     10.0         2.0      20.0    6.0  14.0   86.0
Male        3.0         NaN      28.0    3.0  26.0  111.0
Total      13.0         2.0      48.0    9.0  40.0  197.0
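Here is a self-contained sketch of that renaming step. Instead of calling `objective2`, it builds a frame with the same seven two-level columns by hand, filled with the dummy numbers from the question, so it runs without an Excel file.

```python
import numpy as np
import pandas as pd

# Two-level columns shaped like the pivot table's, including the
# ('Total', '') margin column; values are the question's dummy numbers.
columns = pd.MultiIndex.from_tuples(
    [("WBC < 4", "Alive"), ("WBC < 4", "Died"),
     ("WBC Normal", "Alive"), ("WBC Normal", "Died"),
     ("WBC > 12", "Alive"), ("WBC > 12", "Died"),
     ("Total", "")],
    names=["WBC_groups", "Outcome_at_24"],
)
df = pd.DataFrame(
    [[10.0, 2.0, 20.0, 6.0, 14.0, np.nan, 86.0],
     [3.0, np.nan, 28.0, 3.0, 26.0, 4.0, 111.0],
     [13.0, 2.0, 48.0, 9.0, 40.0, 4.0, 197.0]],
    index=pd.Index(["Female", "Male", "Total"], name="Sex"),
    columns=columns,
)

# Assigning a flat list replaces the MultiIndex; names are matched by
# position, so the list needs exactly one entry per column (7 here).
df.columns = ["WBC < 4", "WBC Normal", "WBC > 12", "Alive", "Died", "rem", "Total"]
df = df.drop("rem", axis=1)
print(df)
```

The names are assigned by position and nothing is re-aggregated: the column now called `Alive`, for example, holds what was in the fourth original column, `('WBC Normal', 'Died')`. A list whose length differs from the number of columns raises a `ValueError`.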

If you need a more general solution, derive the names from the column levels:

print(df)
WBC_groups     WBC < 4       WBC Normal    WBC > 12      Total
Outcome_at_24  Alive   Died  Alive   Died  Alive   Died
Sex
Female          10.0    2.0   20.0    6.0   14.0    NaN   86.0
Male             3.0    NaN   28.0    3.0   26.0    4.0  111.0
Total           13.0    2.0   48.0    9.0   40.0    4.0  197.0

cols1 = df.columns.get_level_values('WBC_groups').to_series().drop_duplicates().tolist()
print(cols1)
['WBC < 4', 'WBC Normal', 'WBC > 12', 'Total']

cols2 = df.columns.get_level_values('Outcome_at_24').to_series().drop_duplicates().tolist()
print(cols2)
['Alive', 'Died', '']

# 3 WBC groups + 2 outcomes + placeholder 'rem' + 'Total' = one name per column (7)
cols = cols1[:-1] + cols2[:2] + ['rem'] + cols1[-1:]
print(cols)
['WBC < 4', 'WBC Normal', 'WBC > 12', 'Alive', 'Died', 'rem', 'Total']

df.columns = cols

df = df.drop('rem', axis=1)
print(df)
        WBC < 4  WBC Normal  WBC > 12  Alive  Died  Total
Sex                                                      
Female     10.0         2.0      20.0    6.0  14.0   86.0
Male        3.0         NaN      28.0    3.0  26.0  111.0
Total      13.0         2.0      48.0    9.0  40.0  197.0
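A different, commonly used way to flatten two-level columns (not the answer's approach, shown only as a sketch) is to join each column's level labels into one string. Every column is kept, so there is no placeholder to drop, and skipping empty labels turns the margin column `('Total', '')` into plain `Total`. The frame below is built by hand with the pivot table's column layout and placeholder zeros.

```python
import pandas as pd

# Same two-level layout as the pivot table; placeholder zero values.
columns = pd.MultiIndex.from_tuples(
    [("WBC < 4", "Alive"), ("WBC < 4", "Died"),
     ("WBC Normal", "Alive"), ("WBC Normal", "Died"),
     ("WBC > 12", "Alive"), ("WBC > 12", "Died"),
     ("Total", "")],
    names=["WBC_groups", "Outcome_at_24"],
)
df = pd.DataFrame(0.0, index=["Female", "Male", "Total"], columns=columns)

# Join the non-empty labels of each (group, outcome) tuple with a space,
# so ('WBC < 4', 'Alive') -> 'WBC < 4 Alive' and ('Total', '') -> 'Total'.
df.columns = [" ".join(part for part in col if part) for col in df.columns]
print(list(df.columns))
```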
answered Jan 23 '26 19:01 by jezrael


