
Why the built-in lm function is so slow in R?

I always thought that the lm function was extremely fast in R, but as this example suggests, the closed-form solution computed with the solve function is considerably faster.

data <- data.frame(y = rnorm(1000), x1 = rnorm(1000), x2 = rnorm(1000))
X <- cbind(1, data$x1, data$x2)

library(microbenchmark)
microbenchmark(
  solve(t(X) %*% X, t(X) %*% data$y),
  lm(y ~ ., data = data))

Can someone explain whether this toy example is a bad one, or whether lm is actually slow?

EDIT: As suggested by Dirk Eddelbuettel, the comparison is unfair because lm has to resolve the formula. It is better to use lm.fit, which does not need to resolve a formula:

microbenchmark(
  solve(t(X) %*% X, t(X) %*% data$y),
  lm.fit(X, data$y))


Unit: microseconds
                              expr     min      lq     mean   median       uq     max neval cld
solve(t(X) %*% X, t(X) %*% data$y)  99.083 108.754 125.1398 118.0305 131.2545 236.060   100  a 
                      lm.fit(X, y) 125.136 136.978 151.4656 143.4915 156.7155 262.114   100   b
asked Apr 12 '16 11:04 by adaien



1 Answer

You are overlooking that

  • solve() returns only the parameter estimates
  • lm() returns a (very rich) object with many components for subsequent analysis, inference, plots, ...
  • the main cost of your lm() call is not the projection but the resolution of the formula y ~ ., from which the model matrix needs to be built.
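
The first two points can be checked directly. The following is a minimal sketch, not part of the original benchmark: the seed and the variable names beta_solve, fit and coef_agree are assumptions introduced here for illustration. It shows that the normal-equations solve yields only a coefficient vector, while lm() yields the same estimates plus a list of further components.

```r
# Minimal sketch: compare what solve() and lm() hand back.
# The seed is an assumption added for reproducibility.
set.seed(1)
data <- data.frame(y = rnorm(1000), x1 = rnorm(1000), x2 = rnorm(1000))
X <- cbind(1, data$x1, data$x2)

# Normal equations: the result is a 3 x 1 matrix of coefficients, nothing else
beta_solve <- solve(crossprod(X), crossprod(X, data$y))

# lm(): the result is an object of class "lm"
fit <- lm(y ~ ., data = data)

# The point estimates agree up to floating-point tolerance
coef_agree <- isTRUE(all.equal(as.numeric(beta_solve), unname(coef(fit))))
print(coef_agree)

# Components carried by the lm object beyond the coefficients
print(names(fit))
```

Here crossprod(X) computes the same matrix as t(X) %*% X; it is used only to keep the sketch short.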

To illustrate Rcpp we wrote a few variants of a function fastLm() that do more of what lm() does (i.e. a bit more than lm.fit() from base R) and measured them. See e.g. this benchmark script, which clearly shows that the dominant cost for smaller data sets is in parsing the formula and building the model matrix.

In short, you are doing the Right Thing by benchmarking, but you are not doing it quite correctly: you are comparing what is mostly incomparable, a subset of the task against a much larger task.
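
One rough way to see where the time goes, using base R only, is to time the formula and model-matrix step separately from the numerical fit. This is a sketch under stated assumptions and not the benchmark script referred to above: the seed, the repetition count n_rep and the variable names are hypothetical, and the timings it prints will vary by machine.

```r
# Sketch: split lm() into "resolve the formula" and "fit the matrix".
# Seed, n_rep and all variable names are assumptions for illustration.
set.seed(1)
data <- data.frame(y = rnorm(1000), x1 = rnorm(1000), x2 = rnorm(1000))

# Step 1: the preparation lm() performs before any linear algebra
mf <- model.frame(y ~ ., data = data)
X_mm <- model.matrix(attr(mf, "terms"), mf)
y <- model.response(mf)

# Step 2: the numerical fit (QR decomposition) on the prepared matrix
fit_raw <- lm.fit(X_mm, y)

# Time each stage over repeated calls
n_rep <- 200
t_prep <- system.time(for (i in seq_len(n_rep)) {
  mf_i <- model.frame(y ~ ., data = data)
  model.matrix(attr(mf_i, "terms"), mf_i)
})[["elapsed"]]
t_fit <- system.time(for (i in seq_len(n_rep)) lm.fit(X_mm, y))[["elapsed"]]
t_lm <- system.time(for (i in seq_len(n_rep)) lm(y ~ ., data = data))[["elapsed"]]

print(c(formula_and_matrix = t_prep, lm_fit_only = t_fit, full_lm = t_lm))
```

system.time() is coarser than microbenchmark, so the numbers are only indicative; the point of the sketch is that the preparation stage can be measured on its own.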

answered Sep 26 '22 04:09 by Dirk Eddelbuettel