
k-fold cross-validation in dplyr?

Tags:

r

dplyr

Hadley Wickham proposed that it might be possible to bootstrap using the dplyr package; his suggestion was improved and then implemented in the broom package. Would it also be possible to implement k-fold cross-validation?

I guess that the first step (selecting the train group) is very simple:

crossvalidate <- function (df, k = 5) {
  n <- nrow(df)
  idx <- sample(rep_len(1:k, n))
  attr(df, "indices") <- lapply(1:k, function(i) which(idx != i))
  attr(df, "drop") <- TRUE
  attr(df, "group_sizes") <- nrow(df) - unclass(table(idx))
  attr(df, "biggest_group_size") <- max(attr(df, "group_sizes"))
  attr(df, "labels") <- data.frame(replicate = 1:k)
  attr(df, "vars") <- list(quote(replicate))
  class(df) <- c("grouped_df", "tbl_df", "tbl", "data.frame")
  df
}

However, I cannot find any documentation of attr(, "indices") to learn whether it is possible to use the indices other than the ones selected, that is, to pick out the test group. Do you have any ideas?

Tim asked Oct 17 '22 22:10


2 Answers

See https://rpubs.com/dgrtwo/cv-modelr for an example of k-fold cross-validation that uses modelr alongside dplyr:

library(ISLR)
library(dplyr)
library(purrr)
library(modelr)
library(broom)
library(tidyr)

set.seed(1)

models <- Smarket %>%
  select(Today, Lag1:Lag5) %>%
  crossv_kfold(k = 20) %>%
  mutate(model = map(train, ~ lm(Today ~ ., data = .)))

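# Predict on each held-out fold, then unnest the list-column by name
# (recent tidyr versions expect the column to unnest to be named explicitly)
predictions <- models %>%
  mutate(augmented = map2(model, test, ~ augment(.x, newdata = as.data.frame(.y)))) %>%
  unnest(augmented)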

predictions %>%
  summarize(MSE = mean((Today - .fitted) ^ 2),
            MSEIntercept = mean((Today - mean(Today))^2))
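
The summary above pools the held-out predictions from all folds into one MSE. If a per-fold error is wanted instead, a minimal sketch along the same lines could look like this. The mtcars data and the formula mpg ~ wt + hp are illustrative assumptions, not part of the linked example; mse() is the model-quality helper from modelr.

```r
library(dplyr)
library(purrr)
library(modelr)

set.seed(1)

# Assumption for illustration: mtcars and mpg ~ wt + hp stand in for
# the real data set and model.
folds <- crossv_kfold(mtcars, k = 5)

per_fold <- folds %>%
  mutate(
    # fit one model per training split
    model = map(train, ~ lm(mpg ~ wt + hp, data = .x)),
    # score each model on its own held-out split
    mse = map2_dbl(model, test, ~ modelr::mse(.x, .y))
  ) %>%
  select(.id, mse)

per_fold            # one row per fold
mean(per_fold$mse)  # average of the per-fold errors
```

Averaging per-fold errors and pooling all predictions give the same number only when every fold has the same size, so the two summaries can differ slightly.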
Jot eN answered Oct 21 '22 09:10


Here is one solution for stratified 5-fold CV using dplyr:

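df_fold <- df %>%
  group_by(group_var) %>%                       # stratify by group_var
  sample_frac(1) %>%                            # shuffle rows within each stratum
  mutate(fold = rep(1:5, length.out = n())) %>% # deal out fold labels 1..5
  ungroup()

for (i in 1:5) {
  val <- df_fold %>% filter(fold == i)              # validation rows for fold i
  tr  <- df_fold %>% anti_join(val, by = "ID_var")  # training rows; ID_var is a unique row ID column
  # fit on tr and evaluate on val here
}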
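
The loop above only builds the two splits for each fold. A minimal sketch of a complete run, fitting a model on the training rows and scoring it on the validation rows, might look like the following. The mtcars data, the id column, cyl as the stratification variable and the model mpg ~ wt are all illustrative assumptions.

```r
library(dplyr)

set.seed(1)

k <- 5

# Assumptions for illustration: mtcars as the data, cyl as the strata,
# and a row-number id column to identify rows in anti_join().
df_fold <- mtcars %>%
  mutate(id = row_number()) %>%
  group_by(cyl) %>%
  sample_frac(1) %>%                            # shuffle within each stratum
  mutate(fold = rep(1:k, length.out = n())) %>% # deal out fold labels
  ungroup()

fold_mse <- numeric(k)
for (i in 1:k) {
  val <- df_fold %>% filter(fold == i)          # held-out rows
  tr  <- df_fold %>% anti_join(val, by = "id")  # everything else
  fit <- lm(mpg ~ wt, data = tr)
  fold_mse[i] <- mean((val$mpg - predict(fit, newdata = val))^2)
}

fold_mse        # one validation MSE per fold
mean(fold_mse)  # cross-validated estimate
```

Because the fold labels are dealt out within each group, every fold receives a roughly proportional share of each stratum.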
user2236600 answered Oct 21 '22 08:10