
How to slice a pandas MultiIndex df keeping all values until a certain condition is met?

I have a 3-level MultiIndex dataframe and I would like to slice it so that all the values up to the point where a certain condition is met are kept. For example, I have the following dataframe:

                           Col1  Col2
Date          Range  Label
'2018-08-01'  1      A     900   815
                     B     850   820
                     C     800   820
                     D     950   840
              2      A     900   820
                     B     750   850
                     C     850   820
                     D     850   800

Within each (Date, Range) block, I would like to keep all the rows until Col1 becomes smaller than Col2. As soon as a row has Col1 < Col2, I no longer care about the remaining rows and would like to drop them (even if Col1 becomes larger than Col2 again). For the example above, this is the dataframe I would like to obtain:

                           Col1  Col2
Date          Range  Label
'2018-08-01'  1      A     900   815
                     B     850   820
              2      A     900   820
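The rule is a prefix cut applied separately to each group: rows are kept while Col1 >= Col2, and everything from the first Col1 < Col2 row onwards is dropped, even if Col1 later recovers. A minimal plain-Python sketch of that rule on the two groups above, using `itertools.takewhile` (the helper `keep_until_violation` is hypothetical, only there to illustrate the semantics):

```python
from itertools import takewhile

# (Label, Col1, Col2) rows of the two Range groups for Date '2018-08-01'
range_1 = [("A", 900, 815), ("B", 850, 820), ("C", 800, 820), ("D", 950, 840)]
range_2 = [("A", 900, 820), ("B", 750, 850), ("C", 850, 820), ("D", 850, 800)]


def keep_until_violation(rows):
    """Hypothetical helper: keep rows until the first one with Col1 < Col2.

    Rows after that point are dropped even if Col1 becomes larger than Col2 again.
    """
    return list(takewhile(lambda row: row[1] >= row[2], rows))


print([label for label, _, _ in keep_until_violation(range_1)])  # ['A', 'B']
print([label for label, _, _ in keep_until_violation(range_2)])  # ['A']
```

A group in which Col1 never drops below Col2 is returned unchanged.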

I have tried several options but haven't found a good solution yet. I can easily keep all the rows for which Col1 > Col2 with:

df_new=df[df['Col1']>df['Col2']]

but this is not what I need. I've also been thinking about looping over the level-1 index values and slicing the dataframe with pd.IndexSlice:

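import numpy as np
import pandas as pd

idx = pd.IndexSlice
idx_lev1 = df.index.get_level_values(1).unique()

sliced = []
for j in idx_lev1:
    # all rows with Range == j (for every Date)
    df_lev1 = df.loc[idx[:, j, :], :]
    # Label of the row just before the first Col1 < Col2 (IndexError if there is none)
    idxs = df_lev1.index.get_level_values(2)[np.where(df_lev1['Col1'] < df_lev1['Col2'])[0][0] - 1]
    sliced.append(df_lev1.loc[idx[:, :, :idxs], :])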

and then concatenating the resulting dataframes afterwards. However, this is not efficient (my dataframe has more than 3 million entries, so performance matters too), and the Range index is repeated for different dates, so I would probably have to nest two loops or something similar.
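Grouping on both outer index levels at once gives one iteration per (Date, Range) pair, so there is no need to nest loops over dates and ranges, and a group where Col1 never drops below Col2 can simply be kept whole. A sketch of that loop, assuming the index levels are named Date and Range as in the test data of this question; it is still a Python-level loop, so it is unlikely to be fast on 3 million rows:

```python
from io import StringIO

import numpy as np
import pandas as pd

s = """
Date  Range  Label  Col1  Col2
'2018-08-01'  1  A  900   815
'2018-08-01'  1  B  850   820
'2018-08-01'  1  C  800   820
'2018-08-01'  1  D  950   840
'2018-08-01'  2  A  900   820
'2018-08-01'  2  B  750   850
'2018-08-01'  2  C  850   820
'2018-08-01'  2  D  850   800
"""
df = pd.read_csv(StringIO(s), sep=r"\s+", index_col=["Date", "Range", "Label"])

pieces = []
# one iteration per (Date, Range) pair, so Range values repeated across dates are fine
for _, group in df.groupby(level=["Date", "Range"], sort=False):
    below = (group["Col1"] < group["Col2"]).to_numpy()
    # position of the first Col1 < Col2 row, or the group length if there is none
    stop = int(np.argmax(below)) if below.any() else len(group)
    pieces.append(group.iloc[:stop])

result = pd.concat(pieces)
print(result)
```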

I'm sure there must be a simpler, more pythonic solution, but I haven't been able to find it.

In case you want to generate the dataframe above for testing you could use:

import pandas as pd
from io import StringIO

s = """
Date  Range  Label  Col1  Col2
'2018-08-01'  1  A  900   815
'2018-08-01'  1  B  850   820
'2018-08-01'  1  C  800   820
'2018-08-01'  1  D  950   840
'2018-08-01'  2  A  900   820
'2018-08-01'  2  B  750   850
'2018-08-01'  2  C  850   820
'2018-08-01'  2  D  850   800
"""
df = pd.read_csv(StringIO(s),
                 sep=r'\s+',
                 index_col=['Date', 'Range', 'Label'])

Update:

I tried both the solution from Adam.Er8 and the one from Alexandre B. They work fine with the test dataframe I created for SO, but not with the real data.
The problem is that there may be groups in which Col1 is always larger than Col2, and in that case I would simply like to keep all of their data. None of the solutions proposed so far handles this case.
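One vectorized way to cover this case (a sketch, not one of the answers mentioned above): mark the rows with Col1 < Col2, then carry the mark forward inside each (Date, Range) group with a cumulative max, so every row from the first violation onwards is flagged. A group in which Col1 is always larger than Col2 is never flagged and is kept whole. The sketch assumes rows are already in the intended order within each group, and rebuilds the realistic test case below with `MultiIndex.from_product`, dropping the quote characters from the dates:

```python
import pandas as pd

# realistic test data from this question, with the quote characters dropped from Date
index = pd.MultiIndex.from_product(
    [["2018-08-01", "2018-08-02"], [1, 2], [1, 2, 3, 4]],
    names=["Date", "Range", "Label"],
)
df = pd.DataFrame(
    {
        "Col1": [900, 950, 900, 950, 900, 750, 850, 850, 900, 850, 800, 950, 900, 750, 850, 850],
        "Col2": [815, 820, 820, 840, 820, 850, 820, 800, 815, 820, 820, 840, 820, 850, 820, 800],
    },
    index=index,
)

# 1 on each row where Col1 < Col2, else 0
below = (df["Col1"] < df["Col2"]).astype(int)
# running max per (Date, Range): stays 1 from the first violation to the end of the group
seen = below.groupby(level=["Date", "Range"]).cummax().astype(bool)
# keep rows before the first violation; a group with no violation stays whole
result = df[~seen]
print(result)
```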

For a more realistic test case you can use this example:

s="""                         
Date  Range  Label  Col1  Col2
'2018-08-01'  1  1  900   815
'2018-08-01'  1  2  950   820
'2018-08-01'  1  3  900   820
'2018-08-01'  1  4  950   840
'2018-08-01'  2  1  900   820
'2018-08-01'  2  2  750   850
'2018-08-01'  2  3  850   820
'2018-08-01'  2  4  850   800
'2018-08-02'  1  1  900   815
'2018-08-02'  1  2  850   820
'2018-08-02'  1  3  800   820
'2018-08-02'  1  4  950   840
'2018-08-02'  2  1  900   820
'2018-08-02'  2  2  750   850
'2018-08-02'  2  3  850   820
'2018-08-02'  2  4  850   800
"""

Alternatively, you can download an HDF file from here. It is a subset of the dataframe I'm actually using.

asked Jul 04 '19 12:07 by baccandr


1 Answer

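I used .cumcount() to number the rows within each (Date, Range) group, then found the first row in each group that meets the condition (Col1 < Col2), and kept only the rows numbered lower than that. Groups where the condition is never met fall back to len(df), so all of their rows are kept.

Try this: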

from collections import defaultdict

import pandas as pd
from io import StringIO

s="""
Date  Range  Label  Col1  Col2
'2018-08-01'  1  1  900   815
'2018-08-01'  1  2  950   820
'2018-08-01'  1  3  900   820
'2018-08-01'  1  4  950   840
'2018-08-01'  2  1  900   820
'2018-08-01'  2  2  750   850
'2018-08-01'  2  3  850   820
'2018-08-01'  2  4  850   800
'2018-08-02'  1  1  900   815
'2018-08-02'  1  2  850   820
'2018-08-02'  1  3  800   820
'2018-08-02'  1  4  950   840
'2018-08-02'  2  1  900   820
'2018-08-02'  2  2  750   850
'2018-08-02'  2  3  850   820
'2018-08-02'  2  4  850   800
"""
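df = pd.read_csv(StringIO(s),
                 sep=r'\s+',
                 index_col=['Date', 'Range', 'Label'])

# number the rows within each (Date, Range) group: 0, 1, 2, ...
groupby_date_range = df.groupby(["Date", "Range"])
df["cumcount"] = groupby_date_range.cumcount()

# per group, the number of the first row with Col1 < Col2;
# groups without such a row default to len(df), so all their rows are kept
first_col1_lt_col2 = defaultdict(
    lambda: len(df),
    df[df['Col1'] < df['Col2']].groupby(["Date", "Range"])["cumcount"].min().to_dict())

# keep the rows numbered before that point; row.name[:2] is the (Date, Range) key
result = df[df.apply(lambda row: row["cumcount"] < first_col1_lt_col2[row.name[:2]], axis=1)].drop(columns="cumcount")
print(result)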

Output:

                          Col1  Col2
Date         Range Label            
'2018-08-01' 1     1       900   815
                   2       950   820
                   3       900   820
                   4       950   840
             2     1       900   820
'2018-08-02' 1     1       900   815
                   2       850   820
             2     1       900   820
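The row-wise `df.apply(..., axis=1)` above calls a Python function once per row, which is likely to dominate the run time on 3 million rows. A sketch of the same cumcount idea without `apply`, assuming the same index level names: broadcast each group's first violating position to all of its rows with `groupby(...).transform("min")`, then compare whole columns at once. The data is the same realistic test case, rebuilt with `MultiIndex.from_product` and without the quote characters in Date:

```python
import pandas as pd

# realistic test data from the question, with the quote characters dropped from Date
index = pd.MultiIndex.from_product(
    [["2018-08-01", "2018-08-02"], [1, 2], [1, 2, 3, 4]],
    names=["Date", "Range", "Label"],
)
df = pd.DataFrame(
    {
        "Col1": [900, 950, 900, 950, 900, 750, 850, 850, 900, 850, 800, 950, 900, 750, 850, 850],
        "Col2": [815, 820, 820, 840, 820, 850, 820, 800, 815, 820, 820, 840, 820, 850, 820, 800],
    },
    index=index,
)

keys = ["Date", "Range"]
# position of each row inside its (Date, Range) group: 0, 1, 2, ...
pos = df.groupby(level=keys).cumcount()
# position of the first Col1 < Col2 row, repeated on every row of its group
# (NaN for groups where Col1 is never smaller than Col2)
first_bad = pos.where(df["Col1"] < df["Col2"]).groupby(level=keys).transform("min")
# a missing value means "no violation", so every row of that group is kept
result = df[pos < first_bad.fillna(len(df))]
print(result)
```

It should select the same eight rows as the output above.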
answered Oct 19 '22 03:10 by Adam.Er8