How to obtain cross-validated r-square from linear model in R?

I have a linear model in R.

x <- rnorm(100)
z <- rnorm(100)
y <- rnorm(100, x+z)
mydata <- data.frame(x,y,z)

fit <- lm(y ~ x + z, mydata)

I would like to obtain an estimate of the out-of-sample r-square. I was thinking of using some form of k-fold cross-validation.

  • What code in R takes a linear model fit and returns a cross-validated r-square?
  • Or is there some other approach to obtaining cross-validated r-square using R?
Jeromy Anglim Avatar asked Apr 16 '13 05:04


What follows is a slight adaptation of the example that @NPR linked to from statmethods. Essentially, I turned the example into a function.


library(bootstrap)  # provides crossval()

k_fold_rsq <- function(lmfit, ngroup=10) {
    # adapted from http://www.statmethods.net/stats/regression.html
    mydata <- lmfit$model
    outcome <- names(lmfit$model)[1]
    predictors <- names(lmfit$model)[-1]

    # fit and predict functions in the form crossval() expects
    theta.fit <- function(x,y){lsfit(x,y)}
    theta.predict <- function(fit,x){cbind(1,x)%*%fit$coef} 
    X <- as.matrix(mydata[predictors])
    y <- as.matrix(mydata[outcome]) 

    results <- crossval(X,y,theta.fit,theta.predict,ngroup=ngroup)
    raw_rsq <- cor(y, lmfit$fitted.values)^2 # raw R2 
    cv_rsq <- cor(y,results$cv.fit)^2 # cross-validated R2

    c(raw_rsq=raw_rsq, cv_rsq=cv_rsq)
}

So, using the data from before:

# sample data
x <- rnorm(100)
z <- rnorm(100)
y <- rnorm(100, x+z)
mydata <- data.frame(x,y,z)

We can fit a linear model and call the cross validation function:

# fit and call function
lmfit <- lm(y ~ x + z, mydata)
k_fold_rsq(lmfit, ngroup=30)

And get the resulting raw and cross-validated r-square:

  raw_rsq    cv_rsq 
0.7237907 0.7050297

Caveat: While raw_rsq is clearly correct and cv_rsq is in the ballpark that I expect, note that I haven't yet examined exactly what the crossval function does. So use at your own risk, and if anyone has any feedback, it would be most welcome. It's also only designed for linear models with an intercept and standard main effects notation.
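
For anyone who wants to check the cross-validation step without relying on the package, here is a minimal base-R sketch of the same idea. It is not what crossval does internally (I have not verified that); it is just a plain k-fold loop written under the same assumptions as above (main-effects formula, intercept, no weights). The names manual_cv_rsq, k, cv_rsq_cor and cv_rsq_sse are made up for this illustration. It pools the held-out predictions and reports both the squared correlation and 1 - SSE/SST.

```r
# Hypothetical helper (not from the bootstrap package): manual k-fold CV in base R.
manual_cv_rsq <- function(lmfit, k = 10) {
  dat <- lmfit$model                 # model frame; the response is column 1
  n <- nrow(dat)
  # Randomly assign each row to one of k folds of near-equal size
  fold <- sample(rep(seq_len(k), length.out = n))
  cv_pred <- numeric(n)
  for (i in seq_len(k)) {
    held_out <- fold == i
    # Refit on the other k - 1 folds, then predict the held-out rows
    fit_i <- lm(formula(lmfit), data = dat[!held_out, , drop = FALSE])
    cv_pred[held_out] <- predict(fit_i, newdata = dat[held_out, , drop = FALSE])
  }
  y <- dat[[1]]
  c(cv_rsq_cor = cor(y, cv_pred)^2,                                  # squared correlation
    cv_rsq_sse = 1 - sum((y - cv_pred)^2) / sum((y - mean(y))^2))    # 1 - SSE/SST
}

# Same kind of sample data as in the question
set.seed(1)
x <- rnorm(100)
z <- rnorm(100)
y <- rnorm(100, x + z)
mydata <- data.frame(x, y, z)
lmfit <- lm(y ~ x + z, mydata)

res <- manual_cv_rsq(lmfit, k = 10)
print(res)
```

The two numbers answer slightly different questions: the squared correlation ignores any bias or scale error in the held-out predictions, while 1 - SSE/SST penalises it, so the second value can never exceed the first.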

Jeromy Anglim Avatar answered Oct 01 '22 10:10


I wrote a function for doing this. It also works for nominal predictors. It only works for lm objects (I think), but could easily be expanded to glm etc.

# from
# http://stackoverflow.com/a/16030020/3980197
# via http://www.statmethods.net/stats/regression.html

#' Calculate k fold cross validated r2
#' Using k fold cross-validation, estimate the true r2 in a new sample. This is better than using adjusted r2 values.
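#' @param lmfit (an lm fit) An lm fit object.
#' @param folds (whole number scalar) The number of folds to use (default 10).
#' @param runs (whole number scalar) The number of times to repeat the cross-validation with a new random shuffle (default 100).
#' @param seed (number scalar) The random seed to set before running (default 1). Use NA to leave the seed alone.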
#' @export
#' @examples
#' fit = lm("Petal.Length ~ Sepal.Length", data = iris)
#' MOD_k_fold_r2(fit)
MOD_k_fold_r2 = function(lmfit, folds = 10, runs = 100, seed = 1) {

  #get data
  data = lmfit$model

  if (!is.na(seed)) set.seed(seed)

  v_runs = sapply(1:runs, FUN = function(run) {
    #Randomly shuffle the data
    data2 = data[sample(nrow(data)), ]

    #Create folds of (near) equal size
    folds_idx <- cut(seq(1, nrow(data2)), breaks = folds, labels = FALSE)

    #Perform k fold cross validation
    fold_r2 = sapply(1:folds, function(i) {
      #Segment your data by fold using the which() function
      test_idx = which(folds_idx==i, arr.ind=TRUE)
      test_data = data2[test_idx, ]
      train_data = data2[-test_idx, ]

      #weights: reuse the original ones if the model had any, else use 1
      if ("(weights)" %in% names(data)) {
        train_data$.weights = train_data[["(weights)"]]
      } else {
        train_data$.weights = rep(1, nrow(train_data))
      }

      #fit on the training folds
      fit = lm(formula = lmfit$call$formula, data = train_data, weights = .weights)

      #predict the held-out fold
      preds = predict(fit, newdata = test_data)

      #correlate to get r2
      cor(preds, test_data[[1]], use = "pairwise.complete.obs")^2
    })

    #average over folds (base R, so no pipe package is needed)
    mean(fold_r2)
  })

  c("raw_r2" = summary(lmfit)$r.squared, "cv_r2" = mean(v_runs))
}

Testing it:

fit = lm("Petal.Length ~ Species", data = iris)
MOD_k_fold_r2(fit)
#>    raw_r2     cv_r2 
#> 0.9413717 0.9398156 

And on the OP sample:

> MOD_k_fold_r2(lmfit)
#raw_r2  cv_r2 
# 0.724  0.718 
CoderGuy123 Avatar answered Oct 01 '22 09:10

