 

Fitting different models to each subset of data in R

Tags: r, data.table

I have a large dataset with multiple classes. My aim is to fit a model to each class, then predict the results and visualize them for each class in a facet.

For a reproducible example, I have created something basic using mtcars. This works well for a single regression model per class:

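library(data.table)
library(ggplot2)

mtcars = data.table(mtcars)
# one lm per cyl, stored in a list column
model = mtcars[, list(fit = list(lm(mpg~disp+hp+wt))), keyby = cyl]
setkey(mtcars, cyl)
# join each cyl group to its model and predict on that group's rows
mtcars[model, pred := predict(i.fit[[1]], .SD), by = .EACHI]
ggplot(data = mtcars, aes(x = mpg, y = pred)) + geom_line() + facet_wrap(~cyl)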
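For the single-model case, fitting and predicting can also be done in one grouped := step, without the separate model table and join. A minimal sketch, working on a fresh data.table copy d of mtcars (the name d is only for illustration):

```r
library(data.table)

d <- as.data.table(datasets::mtcars)
# Fit one lm per cyl group on .SD and add its in-sample predictions as column pred
d[, pred := predict(lm(mpg ~ disp + hp + wt, data = .SD)), by = cyl]
```

The plot is then the same ggplot() call as above, with data = d. The trade-off is that the fitted models are not kept, so this suits the case where only the predictions are needed.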

However, I would like to do something like the code below, which does not work yet. This attempt uses a list of formulas, but I would also like to fit different kinds of models (some GLMs, a few trees) to each subset of data.

mtcars = data.table(mtcars)
factors = list(c("disp","wt"), c("disp"), c("hp"))
form = lapply(factors, function(x) as.formula(paste("mpg~",paste(x,collapse="+"))))
model = mtcars[, list(fit = list(lm(form))), keyby = cyl]
setkey(mtcars, cyl)
mtcars[model, pred := predict(i.fit[[1]], .SD), by = .EACHI]
ggplot(data = mtcars, aes(x = mpg, y = pred)) + geom_line() + facet_wrap(~cyl)
asked Feb 07 '23 06:02 by Divi


2 Answers

Here's an approach where we set up a predict() call for each model as an unevaluated list, evaluate it within the data.table grouped by cyl, reshape the output to long format, and pass it to ggplot:

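library(data.table)
library(ggplot2)

# formulas as in the question
factors = list(c("disp","wt"), c("disp"), c("hp"))
form = lapply(factors, function(x) as.formula(paste("mpg~",paste(x,collapse="+"))))

# one unevaluated predict() call per formula; .SD is the current cyl subset
models = quote(list(
      predict(lm(form[[1]], .SD)),
      predict(lm(form[[2]], .SD)),
      predict(lm(form[[3]], .SD))))

# fresh data.table copy, so := does not also add columns to mtcars by reference
d <- as.data.table(datasets::mtcars)
d[, c("est1", "est2", "est3") := eval(models), by = cyl]
# reshape to long format: one row per observation and model
d <- tidyr::gather(d, key = model, value = pred, est1:est3)

ggplot(d, aes(x = mpg, y = pred)) + geom_line() + facet_grid(cyl ~ model)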

Output: [figure: pred vs. mpg, faceted by cyl (rows) and model (columns)]
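The three hard-coded predict() calls can also be generated from a list, which extends naturally to the question's wish to mix model types. A minimal sketch, assuming a hypothetical named list fitters whose functions each take one cyl subset and return in-sample predictions of mpg. The GLM and the rpart tree are illustrative choices, not part of either answer; rpart is a recommended package and is assumed to be installed, and minsplit = 5 is an arbitrary setting so that the small cyl groups can split at all. It uses data.table's own melt() in place of tidyr::gather().

```r
library(data.table)
library(ggplot2)

# Hypothetical named list of fitting functions; each takes one cyl subset
# (a data.table) and returns in-sample predictions of mpg for its rows.
fitters <- list(
  lm_disp_wt   = function(dat) predict(lm(mpg ~ disp + wt, data = dat)),
  glm_hp       = function(dat) predict(glm(mpg ~ hp, family = gaussian(), data = dat),
                                       type = "response"),
  tree_disp_hp = function(dat) predict(rpart::rpart(mpg ~ disp + hp, data = dat,
                                                    control = rpart::rpart.control(minsplit = 5)))
)

d <- as.data.table(datasets::mtcars)
# Apply every fitter to each cyl subset; .SD is passed in as an argument
d[, names(fitters) := lapply(fitters, function(f, dat) f(dat), dat = .SD), by = cyl]

# Long format: one row per observation and model
d_long <- melt(d, measure.vars = names(fitters),
               variable.name = "model", value.name = "pred")

ggplot(d_long, aes(x = mpg, y = pred)) + geom_line() + facet_grid(cyl ~ model)
```

Because each entry is just a function of the subset, any model with a predict() method can be added to fitters without touching the grouping code.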

answered Feb 08 '23 23:02 by Weihuang Wong


lm() also accepts the formula as a character string, so I'd simply create form as:

form = lapply(factors, function(x) paste("mpg~", paste(x, collapse="+")))
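A quick check of the claim, assuming the factors list from the question: the strings produced by paste() can be passed straight to lm() and give the same coefficients as the equivalent formula object. datasets::mtcars is used so the check does not depend on mtcars having been converted to a data.table.

```r
factors = list(c("disp","wt"), c("disp"), c("hp"))
form = lapply(factors, function(x) paste("mpg~", paste(x, collapse="+")))

form[[1]]   # "mpg~ disp+wt"

# a character formula and the equivalent formula object give the same fit
fit_chr <- lm(form[[1]], data = datasets::mtcars)
fit_fml <- lm(mpg ~ disp + wt, data = datasets::mtcars)
all.equal(coef(fit_chr), coef(fit_fml))   # TRUE
```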

You will also need to supply the data for each group, using data.table's built-in special symbol .SD:

model = mtcars[, list(fit=lapply(form, lm, data=.SD)), keyby=cyl]

For each value of cyl, lapply() loops over form and passes each formula as the first argument to lm(), together with data = .SD. .SD stands for Subset of Data: it is itself a data.table, holding the current group's rows and, by default, all columns except the grouping column cyl. You can read more about it in the data.table vignettes.
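To see what .SD holds in each group, a small sketch that reports its size per cyl value (dt and info are illustrative names):

```r
library(data.table)

dt <- as.data.table(datasets::mtcars)
# For each cyl group: rows and columns of .SD, and whether cyl itself is in .SD
info <- dt[, list(rows    = nrow(.SD),
                  cols    = ncol(.SD),
                  has_cyl = "cyl" %in% names(.SD)), by = cyl]
info
# groups appear in order cyl = 6, 4, 8 with 7, 11 and 14 rows;
# each .SD has 10 columns and does not contain cyl
```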


If you also want the formula in the result:

chform = unlist(form)
model = mtcars[, list(form=chform, fit=lapply(form, lm, data=.SD)), keyby = cyl]
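The question also asks to predict and plot per class. One way to get there from this model table, sketched under the assumption that in-sample predictions are wanted, is to unpack each stored lm into its observed mpg and fitted values. mt and preds are illustrative names; mt is a fresh data.table copy of mtcars.

```r
library(data.table)
library(ggplot2)

mt <- as.data.table(datasets::mtcars)
factors = list(c("disp","wt"), c("disp"), c("hp"))
form = lapply(factors, function(x) paste("mpg~", paste(x, collapse="+")))
chform = unlist(form)
model = mt[, list(form=chform, fit=lapply(form, lm, data=.SD)), keyby = cyl]

# One row per (cyl, formula) in model: unpack the stored lm into the observed
# response (from its model frame) and the in-sample fitted values
preds = model[, list(mpg  = fit[[1]]$model$mpg,
                     pred = fitted(fit[[1]])), by = list(cyl, form)]

ggplot(preds, aes(x = mpg, y = pred)) + geom_line() + facet_grid(cyl ~ form)
```

Using fitted() reads the values stored in each lm object, so nothing has to be re-evaluated against .SD after the grouped call has finished.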

HTH

PS: Please read this post if you plan to use update() within [...] using data.tables.

answered Feb 09 '23 00:02 by Arun