 

Python pandas - filter rows after groupby

For example, I have the following table:

index,A,B
0,0,0
1,0,8
2,0,8
3,1,5
4,1,3

After grouping by A:

0:
index,A,B
0,0,0
1,0,8
2,0,8

1:
index,A,B
3,1,5
4,1,3

What I need is to drop the rows from each group where the value in column B is less than the maximum value of column B within that group. I have trouble formulating this problem in English, so here is an example:

Maximum value of column B in group 0: 8

So I want to drop the row with index 0 and keep the rows with indexes 1 and 2.

Maximum value of column B in group 1: 5

So I want to drop the row with index 4 and keep the row with index 3.

I have tried to use the pandas filter function, but the problem is that it operates on all rows in a group at once:

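import pandas as pd
data = pd.DataFrame({"A": [0, 0, 0, 1, 1], "B": [0, 8, 8, 5, 3]})  # the example table
grouped = data.groupby("A")
# Fails: filter() expects one bool per group, not a row-wise boolean Series
filtered = grouped.filter(lambda x: x["B"] == x["B"].max())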

So what I ideally need is a filter that iterates through all rows in a group.
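For contrast, here is a minimal sketch of what DataFrameGroupBy.filter() is designed for: the function receives each group as a DataFrame and must return a single bool, which keeps or drops that whole group. The threshold of 6 is an arbitrary value chosen only for illustration.

```python
import pandas as pd

# The example table from above (default index 0..4 matches the "index" column)
data = pd.DataFrame({"A": [0, 0, 0, 1, 1], "B": [0, 8, 8, 5, 3]})

# filter() calls the function once per group; a scalar True keeps every row
# of that group, a scalar False drops every row of that group
kept = data.groupby("A").filter(lambda g: g["B"].max() > 6)
print(kept)
```

Group 0 (maximum 8) survives intact and group 1 (maximum 5) disappears entirely, which is why filter() cannot drop individual rows inside a group.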

Thanks for the help!

P.S. Is there also a way to just delete the rows within the groups, without returning a new DataFrame object?

jirinovo asked Dec 15 '14 15:12


3 Answers

You just need to use apply on the groupby object. I modified your example data to make this a little clearer:

import pandas
from io import StringIO

csv = StringIO("""index,A,B
0,1,0.0
1,1,3.0
2,1,6.0
3,2,0.0
4,2,5.0
5,2,7.0""")

df = pandas.read_csv(csv, index_col='index')
groups = df.groupby(by=['A'])
# For each group, keep only the rows whose B equals that group's maximum
print(groups.apply(lambda g: g[g['B'] == g['B'].max()]))

Which prints:

         A    B
A index        
1 2      1  6.0
2 5      2  7.0
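Applied to the table from the question, the same apply idea keeps both tied rows in group 0. This sketch assumes you want group_keys=False, so the result keeps the original row labels instead of adding A as an extra index level.

```python
import pandas as pd

df = pd.DataFrame({"A": [0, 0, 0, 1, 1], "B": [0, 8, 8, 5, 3]})

# group_keys=False keeps the original row labels instead of prepending the group key
result = df.groupby("A", group_keys=False).apply(
    lambda g: g[g["B"] == g["B"].max()]  # every row tied for the group maximum survives
)
print(result)
```

Rows 1 and 2 (both B == 8) and row 3 (B == 5) remain. Newer pandas versions may warn about apply operating on the grouping column; the selected rows are the same.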
Paul H answered Nov 14 '22 08:11


EDIT: I just learned a much neater way to do this using the groupby .transform method:

def get_max_rows(df):
    # One value per row: the maximum of B within that row's A group
    B_maxes = df.groupby('A').B.transform('max')
    return df[df.B == B_maxes]

B_maxes is a Series, indexed identically to the original df, that contains the maximum value of B for each A group. You can pass lots of functions to the transform method; I think they work as long as their output is either a scalar or a vector of the same length as the group. You can even pass strings naming common functions, like 'median'. This is slightly different from Paul H's method in that 'A' won't be an index level in the result, but you can easily set that afterwards.
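To make the transform step concrete, here is a sketch on the question's table showing the broadcast maxima and the resulting mask. The final drop(..., inplace=True) call is one possible approach to the question's P.S.: it removes the rows from the existing DataFrame object instead of binding a new one, although pandas may still build new data internally.

```python
import pandas as pd

df = pd.DataFrame({"A": [0, 0, 0, 1, 1], "B": [0, 8, 8, 5, 3]})

# transform('max') returns one value per original row: that row's group maximum
B_maxes = df.groupby("A")["B"].transform("max")
print(B_maxes.tolist())  # [8, 8, 8, 5, 5]

# Boolean mask aligned with df's index: True where B equals its group maximum
keep = df["B"] == B_maxes
print(df[keep])

# Variation for the P.S.: drop the non-maximal rows from df itself
df.drop(df.index[~keep.to_numpy()], inplace=True)
print(df.index.tolist())  # [1, 2, 3]
```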

import numpy as np
import pandas as pd
df_lots_groups = pd.DataFrame(np.random.rand(30000, 3), columns=list('BCD'))
df_lots_groups['A'] = np.random.choice(range(10000), 30000)

%timeit get_max_rows(df_lots_groups)
100 loops, best of 3: 2.86 ms per loop

%timeit df_lots_groups.groupby('A').apply(lambda df: df[ df.B == df.B.max()])
1 loops, best of 3: 5.83 s per loop

EDIT:

Here's an abstraction that allows you to select rows from groups using any valid comparison operator and any valid groupby method:

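def get_group_rows(df, group_col, condition_col, func='max', comparison='=='):
    g = df.groupby(group_col)[condition_col]
    # Per-row group statistic, aligned with df's index
    condition_limit = g.transform(func)
    # Insert the real column name; @condition_limit refers to the local Series above
    return df.query('{} {} @condition_limit'.format(condition_col, comparison))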

So, for example, if you want all rows above the median B value in each A group, you call

get_group_rows(df, 'A', 'B', 'median', '>')
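A variation on the same abstraction, sketched with Python's operator module instead of building a query string. The function name select_group_rows and the _OPS lookup table are hypothetical, not part of pandas; mapping the comparison symbol to a function avoids formatting column names into an expression.

```python
import operator

import pandas as pd

# Hypothetical lookup from comparison symbol to the matching comparison function
_OPS = {
    "==": operator.eq, "!=": operator.ne,
    "<": operator.lt, "<=": operator.le,
    ">": operator.gt, ">=": operator.ge,
}


def select_group_rows(df, group_col, condition_col, func="max", comparison="=="):
    """Keep rows whose condition_col compares true against the per-group func value."""
    # One group statistic per row, aligned with df's index
    limit = df.groupby(group_col)[condition_col].transform(func)
    mask = _OPS[comparison](df[condition_col], limit)
    return df[mask]


df = pd.DataFrame({"A": [0, 0, 0, 1, 1], "B": [0, 8, 8, 5, 3]})
above_mean = select_group_rows(df, "A", "B", "mean", ">")
print(above_mean)
```

Here the group means are about 5.33 and 4, so rows 1, 2 and 3 are kept.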

A few examples:

%timeit get_group_rows(df_lots_small_groups, 'A', 'B', 'max', '==')
100 loops, best of 3: 2.84 ms per loop
%timeit get_group_rows(df_lots_small_groups, 'A', 'B', 'mean', '!=')
100 loops, best of 3: 2.97 ms per loop
JoeCondron answered Nov 14 '22 06:11


Here's another example: filtering the rows with the maximum value after a groupby operation, using idxmax() and .loc[].

In [465]: import pandas as pd

In [466]:   df = pd.DataFrame({
               'sp' : ['MM1', 'MM1', 'MM1', 'MM2', 'MM2', 'MM2'],
               'mt' : ['S1', 'S1', 'S3', 'S3', 'S4', 'S4'], 
               'value' : [3,2,5,8,10,1]     
                })

In [467]: df
Out[467]: 
   mt   sp  value
0  S1  MM1      3
1  S1  MM1      2
2  S3  MM1      5
3  S3  MM2      8
4  S4  MM2     10
5  S4  MM2      1

### Here, idxmax() finds the index label of the row with the max value within each group,
### and .loc[] selects the rows with those labels:
In [468]: df.loc[df.groupby(["mt"])["value"].idxmax()]
Out[468]: 
   mt   sp  value
0  S1  MM1      3
3  S3  MM2      8
4  S4  MM2     10
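One caveat worth checking against the question's requirement: idxmax() returns only the first row label per group, so tied maxima are not all kept. A minimal sketch on the question's table, where group 0 has two rows with B == 8:

```python
import pandas as pd

df = pd.DataFrame({"A": [0, 0, 0, 1, 1], "B": [0, 8, 8, 5, 3]})

# idxmax picks a single label per group (the first occurrence of the maximum)
first_max = df.loc[df.groupby("A")["B"].idxmax()]
print(first_max.index.tolist())  # [1, 3]

# Comparing against transform('max') keeps every row equal to its group maximum
all_max = df[df["B"] == df.groupby("A")["B"].transform("max")]
print(all_max.index.tolist())  # [1, 2, 3]
```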
Surya answered Nov 14 '22 07:11