How can I apply grouped data to grouped models using broom and dplyr?

I'd like to fit a model of gpm (gallons per mile = 1/mpg) against wt, separately for each combination of cyl and am in the mtcars data set. That seems easy:

data(mtcars)
library(dplyr)
library(tidyr)
library(broom)
library(ggplot2)
library(scales)

mtcars2 <-
    mtcars %>%
    mutate(gpm = 1 / mpg) %>%
    group_by(cyl, am)

lm1 <-
    mtcars2 %>%
    do(fit = lm(gpm ~ wt, data = .))
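For comparison, the same grouped fit can be sketched in base R without dplyr: split() the data into one data frame per cyl/am cell, then lapply() lm() over the pieces. This is a swapped-in base-R equivalent, not what do() does internally; the names `cars`, `cells` and `fits` are illustrative.

```r
# Base-R sketch: one lm(gpm ~ wt) per cyl/am combination
cars <- transform(mtcars, gpm = 1 / mpg)

# split() on a list of factors gives one data frame per cyl/am cell;
# the first factor in the list (am) varies fastest in the names
cells <- split(cars, list(cars$am, cars$cyl))

# fit the same model in every cell
fits <- lapply(cells, function(d) lm(gpm ~ wt, data = d))

length(fits)  # 6 models
names(fits)   # "0.4" "1.4" "0.6" "1.6" "0.8" "1.8"
```

Each name is "am.cyl", so the list can later be indexed by group rather than by position.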

That gets me a rowwise data frame with 6 rows, as expected.

This graph confirms that there are six groups:

p1 <-
    qplot(wt, gpm, data = mtcars2) +
    facet_grid(cyl ~ am) +
    stat_smooth(method = 'lm', se = FALSE, fullrange = TRUE) +
    scale_x_continuous(limits = c(0,NA)) 

I can use augment() to get the fitted outputs:

lm1 %>% augment(fit)

That gives me 32 rows, one for each row in mtcars2, as expected.
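A base-R analogue of augment(fit) keeps each model next to the rows it was fit on, so the fitted values line up one-to-one with the 32 cars. This sketch only reproduces the .fitted column (augment() also adds residuals and other diagnostics), and the object names are illustrative.

```r
cars <- transform(mtcars, gpm = 1 / mpg)
cells <- split(cars, list(cars$am, cars$cyl))

# fit each cell's model and attach its fitted values to that cell's rows
augmented <- do.call(rbind, lapply(cells, function(d) {
  fit <- lm(gpm ~ wt, data = d)
  d$.fitted <- fitted(fit)
  d
}))

nrow(augmented)  # 32, one row per car
```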

Now the challenge: I'd like to get fitted outputs using newdata, where I've incremented wt by cyl/4:

newdata <-
    mtcars2 %>%
    mutate(wt = wt + cyl / 4)

I expect that this will produce a data frame of the same size as lm1 %>% augment(fit): one row for each row in newdata, because broom will match up models and newdata by the grouping variables cyl and am.

Unfortunately,

pred1 <-
    lm1 %>%
    augment(
        fit,
        newdata = newdata)

gives me a data frame with 192 rows (= 6 x 32), apparently applying each of the 6 models to all 32 rows of newdata.
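The 192 rows are consistent with every model being applied to the whole of newdata. The base-R sketch that follows is an illustration of that cross product, not a reconstruction of broom's internals; it shows where 6 x 32 comes from and that only 32 of the rows pair a model with its own group.

```r
cars <- transform(mtcars, gpm = 1 / mpg)
fits <- lapply(split(cars, list(cars$am, cars$cyl)),
               function(d) lm(gpm ~ wt, data = d))

# newdata as in the question: weight shifted by cyl/4
newdata <- transform(cars, wt = wt + cyl / 4)

# apply every model to every row of newdata (the unwanted behaviour)
crossed <- do.call(rbind, lapply(names(fits), function(k) {
  data.frame(model = k, cyl = newdata$cyl, am = newdata$am,
             pred = unname(predict(fits[[k]], newdata = newdata)),
             stringsAsFactors = FALSE)
}))

nrow(crossed)  # 6 * 32 = 192
```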

From reading elsewhere, I gather that group_by and rowwise data frames aren't compatible, so lm1 is ungrouped, and augment can't associate models and newdata. Is there another design pattern that lets me do this? It would be nice if it were as simple and transparent as the above attempt, but it's more important that it work.

Here's my sessionInfo():

> sessionInfo()
R version 3.3.1 (2016-06-21)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 7 x64 (build 7601) Service Pack 1

locale:
[1] LC_COLLATE=English_United States.1252 
[2] LC_CTYPE=English_United States.1252   
[3] LC_MONETARY=English_United States.1252
[4] LC_NUMERIC=C                          
[5] LC_TIME=English_United States.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] scales_0.4.0  ggplot2_2.1.0 broom_0.4.1   tidyr_0.6.0   dplyr_0.5.0  

loaded via a namespace (and not attached):
 [1] Rcpp_0.12.7      magrittr_1.5     mnormt_1.5-4     munsell_0.4.3   
 [5] colorspace_1.2-6 lattice_0.20-34  R6_2.1.3         stringr_1.1.0   
 [9] plyr_1.8.4       tools_3.3.1      parallel_3.3.1   grid_3.3.1      
[13] nlme_3.1-128     gtable_0.2.0     psych_1.6.9      DBI_0.5-1       
[17] lazyeval_0.2.0   assertthat_0.1   tibble_1.2       reshape2_1.4.1  
[21] labeling_0.3     stringi_1.1.1    compiler_3.3.1   foreign_0.8-67  

EDIT:

@aosmith: I have been exploring your second option, and I like it. When I try it on my real data, though, I have a problem in the mutate command: it returns "Error: augment doesn't know how to deal with data of class list".

My real code is more like:

newdata %>%
    dplyr::select(cyl, am, wt) %>% # wt holds new predictor values
    group_by(cyl, am) %>%
    nest() %>%
    inner_join(regressions, .) %>%
    ## looks like yours at this point
    mutate(pred = list(augment(fit, newdata = data))) %>% # Error here
    unnest(pred)

Where I say it looks like yours, I mean I have the following columns (renamed here for consistency): ID (chr), attr1 (dbl), cyl (dbl), am (chr), fit (list), and data (list). You have cyl, am (dbl), fit, and data. I changed my am to dbl, but that didn't help.

I think the difference is that I have 3 (ID ... similar to the rownames in mtcars) x 2 (cyl) x 2 (am) units in this sample (with each sample having 12 measurements), while the mtcars example has 3 (cyl) x 2 (am) cells x a random number of car types per cell. In my analysis, I need to see the ID values, but newdata applies equally to all units. If it helps, think of it as the speed of a headwind applied to each car in the test. Does that suggest a cause for augment's complaint it can't deal with data of class list?

EDIT: Merging the ID with the newdata (using full=TRUE) solved the last problem. I'm currently using your first proposed solution.
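Giving every unit the same new predictor values amounts to a Cartesian product (cross join) of the ID table with the new data. A minimal base-R sketch, assuming hypothetical IDs "A", "B", "C" and hypothetical wt values; merge() with by = NULL returns the Cartesian product. This is one way to do such a merge, not necessarily the exact call used above.

```r
# hypothetical unit IDs and hypothetical new predictor values
ids <- data.frame(ID = c("A", "B", "C"), stringsAsFactors = FALSE)
new_wt <- data.frame(wt = c(2.5, 3.0, 3.5, 4.0))

# by = NULL makes merge() return the Cartesian product:
# every ID gets every new wt value
expanded <- merge(ids, new_wt, by = NULL)

nrow(expanded)  # 3 * 4 = 12
```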

asked Oct 03 '16 17:10 by BillH


1 Answer

I've used map2 from package purrr for this sort of situation. map2 loops through the elements of two lists simultaneously. The lists must be the same length and be in the same order.

The elements of the lists are used as arguments for the function you want to apply (augment, in your case). Here your two lists would be a list of models and a list of datasets, each with one element per cyl/am combination.

Using map2_df returns the results as a data.frame instead of a list.

library(purrr)

I made the list of data.frames to predict on with split(). The order of the splitting factors determines the order of the list (the first factor varies fastest), so I listed am before cyl to match the cyl-then-am row order of lm1.

test_split = split(newdata, list(newdata$am, newdata$cyl))

map2_df(lm1$fit, test_split, ~augment(.x, newdata = .y))
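Base R's Map() plays the same role as map2(): it walks two lists in parallel and pairs elements by position. Because the pairing is positional, this sketch checks that the model list and the split prediction data carry the same names before combining them. It assumes the models were themselves built with split(), so both lists share the "am.cyl" naming; it illustrates the idea and is not a replacement for the purrr code.

```r
cars <- transform(mtcars, gpm = 1 / mpg)
fits <- lapply(split(cars, list(cars$am, cars$cyl)),
               function(d) lm(gpm ~ wt, data = d))

newdata <- transform(cars, wt = wt + cyl / 4)
test_split <- split(newdata, list(newdata$am, newdata$cyl))

# pairing is by position, so make sure the two lists line up
stopifnot(identical(names(fits), names(test_split)))

# predict each group's new rows with that group's model
preds <- Map(function(fit, d) {
  d$.fitted <- predict(fit, newdata = d)
  d
}, fits, test_split)
pred_df <- do.call(rbind, preds)

nrow(pred_df)  # 32, one row per row of newdata
```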

To avoid worrying about order so much, you could nest the prediction data by groups, join this to lm1, and return the results of augment as a list for unnesting.

newdata %>%
    group_by(cyl, am) %>%
    nest() %>%
    inner_join(lm1, .) %>%
    mutate(pred = list(augment(fit, newdata = data))) %>%
    unnest(pred)
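The nest-and-join version matches models to prediction data by the grouping keys rather than by position. The same idea can be sketched in base R by looking each model up by its group name, so the order of the prediction list no longer matters; here the list is deliberately reversed with rev() to show that. Object names are illustrative.

```r
cars <- transform(mtcars, gpm = 1 / mpg)
fits <- lapply(split(cars, list(cars$am, cars$cyl)),
               function(d) lm(gpm ~ wt, data = d))

newdata <- transform(cars, wt = wt + cyl / 4)
# reverse the list so positional matching would pair the wrong groups
test_split <- rev(split(newdata, list(newdata$am, newdata$cyl)))

# look each group's model up by its key instead of by position
pred_list <- lapply(names(test_split), function(key) {
  d <- test_split[[key]]
  d$.fitted <- predict(fits[[key]], newdata = d)
  d
})
pred_df <- do.call(rbind, pred_list)

nrow(pred_df)  # 32
```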
answered Sep 26 '22 17:09 by aosmith