 

Split dataframe into grouped chunks

Tags: python, pandas

I would like to split a dataframe into chunks. I have written a function that splits a dataframe into equal-sized chunks, but I can't figure out how to split by groups.

Each chunk must include every row of any group it contains. I'd also like to control how many groups go into each chunk, since the groups are relatively small.

Example dataframe:

A  1
A  2
B  3
C  1
D  9
D  10

Target splits (include at least two groups):

Split 1:

A  1
A  2
B  3

Split 2:

C  1
D  9
D  10

If helpful, my current function looks like the following:

def split_frame(sequence, size=10000):
    return (sequence[position:position + size] for position in range(0, len(sequence), size))

Help appreciated!

asked Nov 16 '25 19:11 by shbfy


2 Answers

Works in Python 2 and 3:

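import pandas as pd

df = pd.DataFrame(data=['a', 'a', 'b', 'c', 'a', 'a', 'b', 'v', 'v', 'f'], columns=['A'])

def iter_by_group(df, column, num_groups):
    # Collect whole groups until num_groups are buffered, then yield them as one frame.
    groups = []
    for _, group in df.groupby(column):
        groups.append(group)
        if len(groups) == num_groups:
            yield pd.concat(groups)
            groups = []
    # Yield the leftover groups when the group count is not a multiple of num_groups.
    if groups:
        yield pd.concat(groups)

for group in iter_by_group(df, 'A', 2):
    print(group)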

   A
0  a
1  a
4  a
5  a
2  b
6  b

   A
3  c
9  f

   A
7  v
8  v
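
Note that groupby sorts the group keys by default, which is why the chunks above come out as (a, b), (c, f), (v) rather than in the order the values first appear. If the order of first appearance matters, passing sort=False to groupby should preserve it. A minimal sketch, where iter_by_group_in_order is a hypothetical name for this variant and not part of the answer above:

```python
import pandas as pd

def iter_by_group_in_order(df, column, num_groups):
    # Hypothetical variant of iter_by_group: sort=False keeps groups in
    # order of first appearance instead of sorted key order.
    groups = []
    for _, group in df.groupby(column, sort=False):
        groups.append(group)
        if len(groups) == num_groups:
            yield pd.concat(groups)
            groups = []
    # Yield whatever is left when the group count is not a multiple of num_groups.
    if groups:
        yield pd.concat(groups)

df = pd.DataFrame({'A': ['a', 'a', 'b', 'c', 'a', 'a', 'b', 'v', 'v', 'f']})
chunks = list(iter_by_group_in_order(df, 'A', 2))
for chunk in chunks:
    # Print only the group keys that ended up in each chunk.
    print(chunk['A'].unique().tolist())
```

With this frame the chunks should contain (a, b), (c, v) and (f), in that order.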
answered Nov 18 '25 10:11 by Dennis Golomazov


The answer from Dennis Golomazov was too slow for my dataframes. Collecting the groups in a list and joining them with pd.concat() is a performance killer.

Here is a slightly faster version. It numbers the groups with ngroup() and selects each chunk by its range of group numbers.

import pandas as pd

def group_chunks(df, column, chunk_size):
    # Number the groups 0..n-1 (sorted key order) in a separate Series,
    # so the caller's dataframe is not modified.
    n_group = df.groupby(column).ngroup()
    max_group_index = n_group.max()
    lower_group_index = 0
    while lower_group_index <= max_group_index:
        upper_group_index = lower_group_index + chunk_size - 1
        # between() is inclusive on both ends.
        yield df[n_group.between(lower_group_index, upper_group_index)]
        lower_group_index = upper_group_index + 1

df = pd.DataFrame(data=['a', 'a', 'b', 'c', 'a', 'a', 'b', 'v', 'v', 'f'], columns=['A']) 
for chunk in group_chunks(df, 'A', 2):
    print(f"{chunk.sort_values(by='A')}\n")

   A
0  a
1  a
4  a
5  a
2  b
6  b

   A
3  c
9  f

   A
7  v
8  v
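
The same group-number idea can also be written without an explicit loop: integer-dividing the group number by the chunk size gives a chunk id, and grouping by that id yields the chunks. The following is only a sketch applied to the example frame from the question. The column names "key" and "value" and the helper name split_by_groups are assumptions, since the question does not name its columns.

```python
import pandas as pd

# The question's example frame; the column names "key" and "value" are
# assumed, since the question does not name them.
df = pd.DataFrame({
    "key": ["A", "A", "B", "C", "D", "D"],
    "value": [1, 2, 3, 1, 9, 10],
})

def split_by_groups(df, column, groups_per_chunk):
    # ngroup() numbers the groups 0, 1, 2, ...; integer division maps
    # every run of groups_per_chunk consecutive numbers to one chunk id.
    chunk_id = df.groupby(column).ngroup() // groups_per_chunk
    # Grouping by the chunk id keeps all rows of a group in the same chunk.
    return [chunk for _, chunk in df.groupby(chunk_id)]

splits = split_by_groups(df, "key", 2)
for split in splits:
    print(split, end="\n\n")
```

For the question's data this should give two splits: the A and B rows in the first, the C and D rows in the second. I have not timed it against the generator above.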

answered Nov 18 '25 09:11 by Arigion