Filter a GroupBy object where at least 1 row fulfills the condition

Let's say there's this test_df:

import pandas as pd

test_df = pd.DataFrame({'Category': ['P', 'P', 'P', 'Q', 'Q', 'Q'],
                        'Subcategory': ['A', 'B', 'C', 'C', 'A', 'B'],
                        'Value': [2.0, 5., 8., 1., 2., 1.]})

Doing this gives:

test_df.groupby(['Category', 'Subcategory'])['Value'].sum()
# Output is this
Category  Subcategory
P         A              2.0
          B              5.0
          C              8.0
Q         A              2.0
          B              1.0
          C              1.0

I want to filter for Categories where at least one Subcategory has a Value greater than or equal to 3. With the current test_df, Q should be excluded, because none of its rows is greater than or equal to 3. If one of its rows were 5, however, Q would remain in the result.
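As a minimal sketch of the condition itself (an illustration, not code from the thread), the per-row test `Value >= 3` can be grouped by Category and reduced with `any()`, which gives one boolean per Category. The name `has_big_value` is hypothetical.

```python
import pandas as pd

test_df = pd.DataFrame({'Category': ['P', 'P', 'P', 'Q', 'Q', 'Q'],
                        'Subcategory': ['A', 'B', 'C', 'C', 'A', 'B'],
                        'Value': [2.0, 5., 8., 1., 2., 1.]})

# One boolean per Category: True if any of its rows has Value >= 3
has_big_value = (test_df['Value'] >= 3).groupby(test_df['Category']).any()
print(has_big_value)  # P: True, Q: False

# Categories that should survive the filter
print(has_big_value[has_big_value].index.tolist())  # ['P']
```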

I have tried using the following, but it filters out the 'A' Subcategory in Category 'P'.

test_df_grouped = test_df.groupby(['Category', 'Subcategory'])

test_df_grouped.filter(lambda x: (x['Value'] > 2).any()).groupby(['Category', 'Subcategory'])['Value'].sum()
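The attempt above groups by both Category and Subcategory. Every (Category, Subcategory) pair occurs only once in test_df, so each group passed to `filter` is a single row and rows are dropped one at a time. A minimal sketch of one possible fix (not taken from either answer below) groups by Category alone, so `filter` keeps or drops each Category as a whole; the threshold is written as `>= 3` to match the stated requirement. The name `kept` is hypothetical.

```python
import pandas as pd

test_df = pd.DataFrame({'Category': ['P', 'P', 'P', 'Q', 'Q', 'Q'],
                        'Subcategory': ['A', 'B', 'C', 'C', 'A', 'B'],
                        'Value': [2.0, 5., 8., 1., 2., 1.]})

# Group by Category only, so each group holds all rows of one Category
# and filter() keeps or drops that whole Category
kept = test_df.groupby('Category').filter(lambda g: (g['Value'] >= 3).any())

# Aggregate the surviving rows the same way as in the question
result = kept.groupby(['Category', 'Subcategory'])['Value'].sum()
print(result)
```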

Thank you in advance!

asked Aug 24 '18 at 05:08 by johnconnor92


2 Answers

Using loc:

s = test_df.groupby(['Category', 'Subcategory'])['Value'].sum()
s.loc[s[s.ge(3)].index.get_level_values(0).unique()].reset_index()

  Category Subcategory  Value
0        P           A    2.0
1        P           B    5.0
2        P           C    8.0
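To make the one-liner easier to follow, here is a sketch that splits it into its intermediate steps using the same calls as the answer; the names `big` and `cats` are hypothetical.

```python
import pandas as pd

test_df = pd.DataFrame({'Category': ['P', 'P', 'P', 'Q', 'Q', 'Q'],
                        'Subcategory': ['A', 'B', 'C', 'C', 'A', 'B'],
                        'Value': [2.0, 5., 8., 1., 2., 1.]})
s = test_df.groupby(['Category', 'Subcategory'])['Value'].sum()

# Rows of the grouped Series whose summed Value is >= 3
big = s[s.ge(3)]

# Their Category labels (index level 0), deduplicated
cats = big.index.get_level_values(0).unique()
print(list(cats))  # ['P']

# Select every Subcategory row of those Categories, then flatten the index
result = s.loc[cats].reset_index()
print(result)
```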

answered Oct 14 '22 at 18:10 by user3483203


Use:

mask = test_df['Category'].isin(test_df.loc[test_df['Value'] >= 3, 'Category'].unique())
a = test_df[mask]
print (a)
  Category Subcategory  Value
0        P           A    2.0
1        P           B    5.0
2        P           C    8.0

First, get the Category values of all rows that meet the condition:

print (test_df.loc[test_df['Value'] >= 3, 'Category'])
1    P
2    P
Name: Category, dtype: object

For better performance, keep only the unique values (thanks @Sandeep Kadapa):

print (test_df.loc[test_df['Value'] >= 3, 'Category'].unique())
['P']

Then filter the original column with isin:

print (test_df['Category'].isin(test_df.loc[test_df['Value'] >= 3, 'Category'].unique()))
0     True
1     True
2     True
3    False
4    False
5    False
Name: Category, dtype: bool
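Because "at least one Value >= 3" is the same as "the largest Value is >= 3", a possible alternative (a sketch, not part of this answer) broadcasts each Category's maximum back onto its rows with `transform('max')` and builds the mask from that, without collecting the Category labels first. The name `cat_max` is hypothetical.

```python
import pandas as pd

test_df = pd.DataFrame({'Category': ['P', 'P', 'P', 'Q', 'Q', 'Q'],
                        'Subcategory': ['A', 'B', 'C', 'C', 'A', 'B'],
                        'Value': [2.0, 5., 8., 1., 2., 1.]})

# Same length as test_df: each row gets the maximum Value of its Category
cat_max = test_df.groupby('Category')['Value'].transform('max')

# A row is kept when its Category's maximum reaches the threshold
a = test_df[cat_max >= 3]
print(a)
```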

The same approach works for filtering a Series with a MultiIndex after groupby:

s = test_df.groupby(['Category', 'Subcategory'])['Value'].sum()
print (s)
Category  Subcategory
P         A              2.0
          B              5.0
          C              8.0
Q         A              2.0
          B              1.0
          C              1.0
Name: Value, dtype: float64

idx0 = s.index.get_level_values(0)
a = s[idx0.isin(idx0[s >= 3].unique())]
print (a)
Category  Subcategory
P         A              2.0
          B              5.0
          C              8.0
Name: Value, dtype: float64
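Another option for the grouped Series, sketched here as an assumption rather than taken from either answer, is to group `s` by its `Category` index level and use `SeriesGroupBy.filter`, which keeps or drops whole groups:

```python
import pandas as pd

test_df = pd.DataFrame({'Category': ['P', 'P', 'P', 'Q', 'Q', 'Q'],
                        'Subcategory': ['A', 'B', 'C', 'C', 'A', 'B'],
                        'Value': [2.0, 5., 8., 1., 2., 1.]})
s = test_df.groupby(['Category', 'Subcategory'])['Value'].sum()

# Group the summed Series by its first index level and keep only the
# Categories where at least one summed Value is >= 3
a = s.groupby(level='Category').filter(lambda g: (g >= 3).any())
print(a)
```

The original MultiIndex is preserved, so the result has the same shape as the `isin` version above.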
answered Oct 14 '22 at 18:10 by jezrael