 

Sample n rows from each group in pandas

Tags: python, pandas

In the dataframe below, I am doing groupby on three fields: 'Subject', 'Rep' and 'yval'.

import numpy as np
import pandas as pd
from functools import reduce  # reduce is not a builtin in Python 3

yval = [1]*30 + [2]*20 + [1]*20 + [2]*30
df = pd.DataFrame({'yval': yval, 'xval': np.random.randn(100)})
df['Subject'] = ['S01'] * 50 + ['S02'] * 50
# Rep runs 0..k-1 in blocks of 10 rows, with k = 3, 2, 2, 3 for the four (Subject, yval) blocks
l = [[x] * 10 for x in range(3)] + [[x] * 10 for x in range(2)] + [[x] * 10 for x in range(2)] + [[x] * 10 for x in range(3)]
l = reduce(lambda x, y: x + y, l)
df['Rep'] = l
df


for k, t in df.groupby(['Subject', 'yval', 'Rep']):
    print(k)


('S01', 1, 0)
('S01', 1, 1)
('S01', 1, 2)
('S01', 2, 0)
('S01', 2, 1)
('S02', 1, 0)
('S02', 1, 1)
('S02', 2, 0)
('S02', 2, 1)
('S02', 2, 2)
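Each of the ten keys above covers exactly 10 rows, which matters when picking n. A quick check is sketched below; it rebuilds the same frame with np.repeat instead of reduce (the name rep is only used in this sketch) and counts rows per group with GroupBy.size.

```python
import numpy as np
import pandas as pd

# Same layout as the question's frame, built without functools.reduce:
# four (Subject, yval) blocks whose Rep values run 0..k-1 in runs of 10 rows.
yval = [1] * 30 + [2] * 20 + [1] * 20 + [2] * 30
rep = np.concatenate([np.repeat(np.arange(k), 10) for k in (3, 2, 2, 3)])
df = pd.DataFrame({'yval': yval, 'xval': np.random.randn(100)})
df['Subject'] = ['S01'] * 50 + ['S02'] * 50
df['Rep'] = rep

# Number of rows in each (Subject, yval, Rep) group
sizes = df.groupby(['Subject', 'yval', 'Rep']).size()
print(sizes)
```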

I am trying to find a way to select n rows from each group. In this example, assuming n = 2, we might get the following result. If n = 4, I expect everything (the entire dataframe).

('S01', 1, 0)
('S01', 1, 2)
('S01', 2, 0)
('S01', 2, 1)
('S02', 1, 0)
('S02', 1, 1)
('S02', 2, 1)
('S02', 2, 2)
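The example output keeps whole (Subject, yval, Rep) groups: within each (Subject, yval) pair it keeps two Rep values, or all of them when the pair only has two. That matches "sample n Rep values per (Subject, yval)", which would also explain why n = 4 returns the entire dataframe. A minimal sketch of that reading, assuming it is the intent; sample_reps is a hypothetical helper name, not a pandas function.

```python
import numpy as np
import pandas as pd

yval = [1] * 30 + [2] * 20 + [1] * 20 + [2] * 30
rep = np.concatenate([np.repeat(np.arange(k), 10) for k in (3, 2, 2, 3)])
df = pd.DataFrame({'yval': yval, 'xval': np.random.randn(100)})
df['Subject'] = ['S01'] * 50 + ['S02'] * 50
df['Rep'] = rep


def sample_reps(frame, n, seed=None):
    """Hypothetical helper: keep up to n randomly chosen Rep values
    (with all of their rows) inside each (Subject, yval) pair.
    Pairs with n or fewer Rep values are kept whole."""
    rng = np.random.default_rng(seed)
    parts = []
    for _, g in frame.groupby(['Subject', 'yval']):
        reps = g['Rep'].unique()
        chosen = rng.choice(reps, size=min(n, len(reps)), replace=False)
        parts.append(g[g['Rep'].isin(chosen)])
    return pd.concat(parts)


two = sample_reps(df, 2, seed=0)
print(two.groupby(['Subject', 'yval', 'Rep']).size())
```

With n = 4 every pair has at most three Rep values, so sample_reps(df, 4) returns all 100 rows.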


asked Jan 30 '26 01:01 by learner


1 Answer

The previous answer selects n groups, whereas the OP wants to select n rows from each group. That can be done as follows:

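gps = df.groupby(['Subject', 'yval', 'Rep'])
ix = np.hstack([np.random.choice(v, n, replace=False) for v in gps.groups.values()])

Then df.loc[ix] gives n randomly selected rows from each group. gps.groups maps each key to the index labels of that group's rows, hence .loc; with the default RangeIndex, df.iloc[ix] (square brackets, not parentheses) selects the same rows. Since replace=False, np.random.choice raises a ValueError for any group with fewer than n rows.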
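On pandas 1.1 or newer, GroupBy.sample draws n rows per group directly. A sketch of that, plus a shuffle-then-head variant (DataFrame.sample(frac=1) followed by GroupBy.head) that returns smaller groups whole instead of raising. The seed 0 and the variable names sampled and capped are arbitrary choices for this sketch.

```python
import numpy as np
import pandas as pd

keys = ['Subject', 'yval', 'Rep']
yval = [1] * 30 + [2] * 20 + [1] * 20 + [2] * 30
rep = np.concatenate([np.repeat(np.arange(k), 10) for k in (3, 2, 2, 3)])
df = pd.DataFrame({'yval': yval, 'xval': np.random.randn(100)})
df['Subject'] = ['S01'] * 50 + ['S02'] * 50
df['Rep'] = rep

n = 2

# pandas >= 1.1: n random rows from every group (ValueError if a group has fewer than n rows)
sampled = df.groupby(keys).sample(n=n, random_state=0)

# Shuffle all rows, then keep the first n rows of each group;
# groups with fewer than n rows come back whole.
capped = df.sample(frac=1, random_state=0).groupby(keys).head(n)

print(sampled.sort_index())
```

The shuffle-then-head form also works on pandas versions older than 1.1, since it only relies on DataFrame.sample and GroupBy.head.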

answered Jan 31 '26 15:01 by erensezener