
dplyr - filter by group size

What is the best way to filter a data.frame so that only groups of a given size, say 5, remain?

So my data looks as follows:

require(dplyr)
n <- 1e5
x <- rnorm(n)
# Group sizes range from 1 to 5
cat <- rep(seq_len(n/3), sample(1:5, n/3, replace = TRUE))[1:n]

dat <- data.frame(x = x, cat = cat)

The dplyr way I could come up with was:

dat <- group_by(dat, cat)

system.time({
  out1 <- dat %>% filter(n() == 5L)
})
#    user  system elapsed 
#   1.157   0.218   1.497

But this is very slow... Is there a better way in dplyr?

So far my workaround looks as follows:

system.time({
  all_ind <- rep(seq_len(n_groups(dat)), group_size(dat))
  take_only <- which(group_size(dat) == 5L)
  out2 <- dat[all_ind %in% take_only, ]
})
#    user  system elapsed 
#   0.026   0.008   0.036
all.equal(out1, out2) # TRUE

But this doesn't feel very dplyr-like...
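To make the index trick above easier to follow, here is a minimal sketch on a toy data frame (the names `toy`, `toy_grouped`, and `kept` are made up for illustration). Note the assumption it depends on: the rows must already be sorted by group, as they are in the example data, otherwise the per-row group index will not line up with the rows.

```r
library(dplyr)

# Toy data: group 1 has 2 rows, group 2 has 5 rows, group 3 has 1 row.
toy <- data.frame(x = 1:8, cat = c(1, 1, 2, 2, 2, 2, 2, 3))
toy_grouped <- group_by(toy, cat)

# One group index per row. This only matches the row order because
# `toy` is already sorted by `cat` (an assumption of this workaround).
all_ind <- rep(seq_len(n_groups(toy_grouped)), group_size(toy_grouped))

# Indices of the groups that have exactly 5 rows.
take_only <- which(group_size(toy_grouped) == 5L)

# Keep only the rows whose group index is among the size-5 groups.
kept <- toy_grouped[all_ind %in% take_only, ]
```

Here only the five rows of group 2 should survive; on unsorted data the same code would silently pick the wrong rows.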

Rentrop asked Mar 30 '17 06:03



2 Answers

You can do it more concisely with n():

library(dplyr)
dat %>% group_by(cat) %>% filter(n() == 5)
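If the grouped `filter(n() == 5)` is too slow, two other dplyr formulations might be worth trying. This is only a sketch on toy data (the names `toy`, `size`, `keep`, `out_a`, and `out_b` are made up here), and it has not been benchmarked, so whether either one is actually faster will depend on the data and the dplyr version.

```r
library(dplyr)

set.seed(1)
toy <- data.frame(x = rnorm(8), cat = c(1, 1, 2, 2, 2, 2, 2, 3))

# Option 1: attach each row's group size as a column, then filter on it.
out_a <- toy %>%
  add_count(cat, name = "size") %>%
  filter(size == 5L) %>%
  select(-size)

# Option 2: find the qualifying groups first, then keep the matching rows.
keep <- toy %>% count(cat) %>% filter(n == 5L)
out_b <- semi_join(toy, keep, by = "cat")
```

Neither option relies on the rows being sorted by `cat`, and both work on an ungrouped data frame.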
Joe answered Sep 29 '22 13:09


I know you asked for a dplyr solution but if you combine it with some purrr you can get it in one line without specifying any new functions. (A little slower though.)

library(dplyr)
library(purrr)
library(tidyr)

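dat %>% 
  group_by(cat) %>% 
  nest() %>% 
  # nrow() gives the group size; n_distinct() would count only distinct rows
  mutate(n = map_int(data, nrow)) %>%
  filter(n == 5) %>% 
  # one row per size-5 group; to get the rows themselves,
  # use unnest(data) here instead
  select(cat, n)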
ceefel answered Sep 29 '22 13:09