 

Dask: Getting the Row which has the max value in groups using groupby

The same problem can be solved in Pandas using transform, as explained here. With Dask, the only working solution I found uses merge, and I was wondering if there are other ways to achieve the same result.
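For reference, here is a minimal sketch of the transform-based Pandas approach referred to above (the column names foo, bar, cnt, and val follow the example data used in the answer below):

import pandas as pd

df = pd.DataFrame({
        'foo' : ['MM1', 'MM1', 'MM1', 'MM2', 'MM2', 'MM2', 'MM4', 'MM4', 'MM4'],
        'bar' : ['S1', 'S1', 'S3', 'S3', 'S4', 'S4', 'S2', 'S2', 'S2'],
        'cnt' : [3, 2, 5, 8, 10, 1, 2, 2, 7],
        'val' : ['a', 'n', 'cb', 'mk', 'bg', 'dgb', 'rd', 'cb', 'uyi'],
    })

# transform('max') broadcasts each group's max back onto the original rows,
# so a boolean mask keeps the row(s) holding the per-group maximum
mask = df.groupby(['foo', 'bar'])['cnt'].transform('max') == df['cnt']
print(df[mask])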

rpanai asked Oct 29 '22 05:10

1 Answer

First, I want to rewrite the referenced script in your original question to make sure I have understood its intent. As far as I can tell, you essentially want a way to extract, for each unique pairing of foo and bar, the row that has the highest cnt value. Below is roughly how the referenced script accomplishes that with just Pandas.

import pandas as pd

# create an example dataframe
df = pd.DataFrame({
        'foo' : ['MM1', 'MM1', 'MM1', 'MM2', 'MM2', 'MM2', 'MM4', 'MM4', 'MM4'],
        'bar' : ['S1', 'S1', 'S3', 'S3', 'S4', 'S4', 'S2', 'S2', 'S2'],
        'cnt' : [3, 2, 5, 8, 10, 1, 2, 2, 7],
        'val' : ['a', 'n', 'cb', 'mk', 'bg', 'dgb', 'rd', 'cb', 'uyi'],
    })


grouped_df = (df.groupby(['foo', 'bar'])            # groups on a (foo, bar) MultiIndex
                .agg({'cnt': 'max'})                # returns the max cnt value from each grouping
                .rename(columns={'cnt': 'cnt_max'}) # renames the col to avoid conflicts on merge later
                .reset_index())                     # turns the MultiIndex levels back into columns

merged_df = pd.merge(df, grouped_df, how='left', on=['foo', 'bar'])

# note: I believe a shortcoming here is that if there is more than one row with the
# max cnt, this would return multiple results for some pairings...
final_df = merged_df[merged_df['cnt'] == merged_df['cnt_max']]
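If ties should be reduced to a single row per pairing, one simple follow-up (my addition, not part of the referenced script) is to deduplicate on the grouping keys after the filter:

# keep only the first of any tied rows per (foo, bar) pairing
final_df = final_df.drop_duplicates(subset=['foo', 'bar'], keep='first')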

Now, here is my take on a Dask-ready version. See the comments for elaboration.

import numpy as np
import pandas as pd
import dask.dataframe as dd

# create an example dataframe
df = pd.DataFrame({
        'foo' : ['MM1', 'MM1', 'MM1', 'MM2', 'MM2', 'MM2', 'MM4', 'MM4', 'MM4'],
        'bar' : ['S1', 'S1', 'S3', 'S3', 'S4', 'S4', 'S2', 'S2', 'S2'],
        'cnt' : [3, 2, 5, 8, 10, 1, 2, 2, 7],
        'val' : ['a', 'n', 'cb', 'mk', 'bg', 'dgb', 'rd', 'cb', 'uyi'],
    })

# I'm not sure we can rely on val to be a column of unique values, so I am just going
# to make a new id column instead. On a very large dataframe that wouldn't fit in
# memory this may not be a reasonable way to create a unique new column, but for the
# purposes of this example it will be sufficient.
df['id'] = np.arange(len(df))

# now let's convert this dataframe into a Dask dataframe
# we use only 1 partition because this is a small sample; a real-world case would use more
ddf = dd.from_pandas(df, npartitions=1)

# create a function that takes each grouped sub-dataframe and returns the id of the
# row where cnt is greatest
def select_max(grouped_df):
    # idxmax returns the index label of the max row, which .loc can look up
    # (argmax returns a positional index in modern pandas, which would break .loc)
    row_with_max_cnt_index = grouped_df['cnt'].idxmax()
    row_with_max_cnt = grouped_df.loc[row_with_max_cnt_index]
    return row_with_max_cnt['id']

# now chain that function into an apply run on the output of the groupby operation
# note: this also may not be the best strategy if the resulting list is too long;
# if that is the case, we will need to better thread the output of this into the next step
# (an explicit meta avoids a warning and lets Dask know the output is an int64 Series)
keep_ids = ddf.groupby(['foo', 'bar']).apply(select_max, meta=('id', 'int64')).compute()

# this is pretty straightforward, just get the rows that match the ids from the max cnt applied method
subset_df = ddf[ddf['id'].isin(keep_ids)]
print(subset_df.compute())
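For comparison, here is a sketch of the merge-based Dask solution the question already mentions. It mirrors the Pandas version above, and the same caveat about tied max values applies:

# group to the per-pairing max, merge it back on, and filter -- the same
# strategy as the Pandas version, but using Dask's merge
grouped_ddf = (ddf.groupby(['foo', 'bar'])
                  .agg({'cnt': 'max'})
                  .rename(columns={'cnt': 'cnt_max'})
                  .reset_index())

merged_ddf = dd.merge(ddf, grouped_ddf, how='left', on=['foo', 'bar'])
print(merged_ddf[merged_ddf['cnt'] == merged_ddf['cnt_max']].compute())

# either way, the surviving rows for this example data should be:
#    foo bar  cnt  val
#    MM1  S1    3    a
#    MM1  S3    5   cb
#    MM2  S3    8   mk
#    MM2  S4   10   bg
#    MM4  S2    7  uyi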
kuanb answered Nov 15 '22 06:11