collapse a pandas MultiIndex

Suppose I have a DataFrame with MultiIndex columns. How can I collapse the levels to a concatenation of the values so that I only have one level?

Setup

import numpy as np
import pandas as pd

np.random.seed([3, 14])
col = pd.MultiIndex.from_product([list('ABC'), list('DE'), list('FG')])
df = pd.DataFrame(np.random.rand(4, 12) * 10, columns=col).astype(int)

print(df)

   A           B           C         
   D     E     D     E     D     E   
   F  G  F  G  F  G  F  G  F  G  F  G
0  2  1  1  7  5  9  9  2  7  4  0  3
1  3  7  1  1  5  3  1  4  3  5  6  0
2  2  6  9  9  9  5  7  0  1  2  7  5
3  2  2  8  0  3  9  4  7  0  8  2  5

I want the result to look like this:

   ADF  ADG  AEF  AEG  BDF  BDG  BEF  BEG  CDF  CDG  CEF  CEG
0    2    1    1    7    5    9    9    2    7    4    0    3
1    3    7    1    1    5    3    1    4    3    5    6    0
2    2    6    9    9    9    5    7    0    1    2    7    5
3    2    2    8    0    3    9    4    7    0    8    2    5
asked May 07 '16 09:05 by piRSquared


2 Answers

Solution

I did this:

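def collapse_columns(df):
    # Work on a copy so the caller's frame is left untouched.
    df = df.copy()
    # Only a MultiIndex has tuple labels to join; flat columns pass through unchanged.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.to_series().apply(lambda x: "".join(x))
    return df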

I had to check whether the columns are a MultiIndex. If they aren't, each label is a plain string, and the join would split it into characters and recombine them with whatever separator I chose.
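To make the reason for that check concrete, here is a minimal sketch of the failure mode and of a guarded variant that takes a separator. The name `collapse_columns_sep` and the `sep` parameter are hypothetical, added for illustration only, and the small frames are made-up examples rather than the data from the question.

```python
import pandas as pd

# A plain string is iterable, so join operates on its characters.
mangled = "_".join("ADF")               # flat label: characters get separated
joined = "_".join(("A", "D", "F"))      # MultiIndex label tuple: intended use
print(mangled, joined)


def collapse_columns_sep(df, sep="_"):
    """Return a copy of df with MultiIndex columns joined into single strings.

    Hypothetical helper: sep is an illustrative parameter.
    """
    out = df.copy()
    if isinstance(out.columns, pd.MultiIndex):
        # Iterating a MultiIndex yields one tuple per column label;
        # str() guards against non-string level values.
        out.columns = [sep.join(map(str, label)) for label in out.columns]
    return out


multi = pd.DataFrame(
    [[1, 2, 3, 4]],
    columns=pd.MultiIndex.from_product([list("AB"), list("DE")]),
)
flat = pd.DataFrame([[1, 2]], columns=["ADF", "BEG"])

multi_cols = list(collapse_columns_sep(multi).columns)
flat_cols = list(collapse_columns_sep(flat).columns)
print(multi_cols)
print(flat_cols)
```

With the guard in place, a frame whose columns are already flat comes back with its labels unchanged instead of being split character by character.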

answered Sep 18 '22 09:09 by piRSquared


You may try this:

In [200]: cols = pd.Series(df.columns.tolist()).apply(pd.Series).sum(axis=1)

In [201]: cols
Out[201]:
0     ADF
1     ADG
2     AEF
3     AEG
4     BDF
5     BDG
6     BEF
7     BEG
8     CDF
9     CDG
10    CEF
11    CEG
dtype: object
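The Series above holds only the new labels, so they still have to be assigned back to the frame's columns. As an alternative that is not part of this answer, the following sketch builds the labels with `map` on the column index and assigns them in one step. It assumes every level value is a string, and it uses a one-row stand-in frame rather than the random data from the question.

```python
import pandas as pd

col = pd.MultiIndex.from_product([list("ABC"), list("DE"), list("FG")])
df = pd.DataFrame([list(range(12))], columns=col)  # stand-in data

flat = df.copy()
# map calls the function once per label tuple, e.g. ('A', 'D', 'F').
flat.columns = flat.columns.map("".join)

labels = flat.columns.tolist()
print(labels)
```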
answered Sep 20 '22 09:09 by MaxU - stop WAR against UA