
data.table grouping column is length 1 in "J"

As I'm learning data.table, I found a situation I can't elegantly work around.

Up front: the absurdity of the lm formula is obvious. I'm trying to determine whether this nuance can be easily worked around with a keyword or special operator within the data.table ecosystem.

library(data.table)
mt <- as.data.table(mtcars)
mt[, list(model = list(lm(mpg ~ disp))), by = "cyl"]
#    cyl model
# 1:   6  <lm>
# 2:   4  <lm>
# 3:   8  <lm>
mt[, list(model = list(lm(mpg ~ disp + cyl))), by = "cyl"]
# Error in model.frame.default(formula = mpg ~ disp + cyl, drop.unused.levels = TRUE) : 
#   variable lengths differ (found for 'cyl')

This is because inside the j expression, the grouping column cyl is a vector of length 1, not a full-length column like the rest of the values:

mt[, list(model = { browser(); list(lm(mpg ~ cyl+disp)); }), by = "cyl"]
# Called from: `[.data.table`(mt, , list(model = {
#     browser()
#     list(lm(mpg ~ cyl + disp))
#   ...
# Browse[1]> 
# debug at #1: list(lm(mpg ~ cyl + disp))
# Browse[2]> 
disp
# [1] 160.0 160.0 258.0 225.0 167.6 167.6 145.0
# Browse[2]> 
cyl
# [1] 6
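The same observation can be made without stepping through browser(). This is a minimal sketch that reports, per group, the length of the grouping column next to an ordinary column; the result name `lens` and the column names `len_cyl`, `len_disp`, and `n` are made up for illustration.

```r
library(data.table)
mt <- as.data.table(mtcars)

# For each cyl group, compare the length of the grouping column (cyl)
# with the length of a regular column (disp) and the group size (.N).
lens <- mt[, .(len_cyl = length(cyl), len_disp = length(disp), n = .N), by = "cyl"]
lens
```

In every group `len_cyl` should be 1 while `len_disp` matches `n`, which is the mismatch that model.frame complains about.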

The most straightforward workaround appears to be to lengthen it manually, either as a temporary variable or inline where it is needed:

mt[, list(model = { cyl2 <- rep(cyl, nrow(.SD)); list(lm(mpg ~ cyl2+disp)); }), by = "cyl"]
mt[, list(model = list(lm(mpg ~ rep(cyl, nrow(.SD))+disp))), by = "cyl"]
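A small variation on the temporary-variable form, as a sketch: data.table's `.N` holds the number of rows in the current group, so it can stand in for `nrow(.SD)`. The names `cyl_full` and `models` are illustrative only.

```r
library(data.table)
mt <- as.data.table(mtcars)

# Expand the length-1 grouping value to the group size using .N,
# then fit the model against the expanded copy.
models <- mt[, list(model = {
  cyl_full <- rep(cyl, .N)
  list(lm(mpg ~ cyl_full + disp))
}), by = "cyl"]
models
```

Because `cyl_full` is constant within each group, lm cannot estimate its coefficient and reports it as NA; that is the expected consequence of the deliberately absurd formula, not a failure of the workaround.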

Q: Is there a more elegant way to deal with this?


Various loosely-related questions, seeding my curiosity (towards embedding "stuff" in DT objects):

  • Setting column name in "group by" operation with data.table
  • Run a function inside data.table in R
  • Using data.table to create a column of regression coefficients

Candidates so far, many good:

mt[, .(model = .(lm(mpg ~ cyl + disp, data = mt[.I]))), by = .(cyl)]
mt[, .(model = .(lm(mpg ~ cyl + disp))), by =.(cylgroup=cyl)]
mt[, .(model = .(lm(mpg ~ cyl + disp, .SD))), by=cyl, .SDcols=names(mt)]
mt[, .(model = .(lm(mpg ~ cyl + disp, .SD))), by=cyl, .SDcols=TRUE]
mt[, .(model = .(lm(mpg ~ cyl + disp, data = cbind(.SD, as.data.table(.BY))))), by = "cyl"]
r2evans asked Oct 31 '18 22:10



1 Answer

Thanks to all for the candidates.

mt[, .(model = .(lm(mpg ~ cyl + disp, data = mt[.I]))), by = .(cyl)]
mt[, .(model = .(lm(mpg ~ cyl + disp))), by =.(cylgroup=cyl)]
mt[, .(model = .(lm(mpg ~ cyl + disp, .SD))), by=cyl, .SDcols=names(mt)]
mt[, .(model = .(lm(mpg ~ cyl + disp, .SD))), by=cyl, .SDcols=TRUE]
mt[, .(model = .(lm(mpg ~ cyl + disp, data = cbind(.SD, as.data.table(.BY))))), by = "cyl"]
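Before comparing timings, it may be worth confirming that the candidates fit the same models. The sketch below checks only the first two candidates, on the assumption that they behave as described above; the names `c1`, `c2`, `coefs1`, `coefs2`, and `same` are illustrative.

```r
library(data.table)
mt <- as.data.table(mtcars)

# Candidate 1: re-subset the full table with the group's row indices (.I).
c1 <- mt[, .(model = .(lm(mpg ~ cyl + disp, data = mt[.I]))), by = .(cyl)]
# Candidate 2: rename the grouping column so cyl stays a full-length column in j.
c2 <- mt[, .(model = .(lm(mpg ~ cyl + disp))), by = .(cylgroup = cyl)]

# Pull the coefficient vectors out of each list column and compare them.
coefs1 <- lapply(c1$model, coef)
coefs2 <- lapply(c2$model, coef)
same <- isTRUE(all.equal(coefs1, coefs2))
same
```

The other candidates could be added to the comparison in the same way.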

With this small model, the candidates show only small differences in performance:

library(microbenchmark)
microbenchmark(
  c1 = mt[, .(model = .(lm(mpg ~ cyl + disp, data = mt[.I]))), by = .(cyl)],
  c2 = mt[, .(model = .(lm(mpg ~ cyl + disp))), by =.(cylgroup=cyl)],
  c3 = mt[, .(model = .(lm(mpg ~ cyl + disp, .SD))), by=cyl, .SDcols=names(mt)],
  c4 = mt[, .(model = .(lm(mpg ~ cyl + disp, .SD))), by=cyl, .SDcols=TRUE],
  c5 = mt[, .(model = .(lm(mpg ~ cyl + disp, data = cbind(.SD, as.data.table(.BY))))), by = "cyl"]
)
# Unit: milliseconds
#  expr    min      lq     mean  median      uq     max neval
#    c1 3.7328 4.21745 4.584591 4.43485 4.57465  9.8924   100
#    c2 2.6740 3.11295 3.244856 3.21655 3.28975  5.6725   100
#    c3 2.8219 3.30150 3.618646 3.46560 3.81250  6.8010   100
#    c4 2.9084 3.27070 3.620761 3.44120 3.86935  6.3447   100
#    c5 5.6156 6.37405 6.832622 6.54625 7.03130 13.8931   100

With larger data:

mtbigger <- rbindlist(replicate(1000, mtcars, simplify=FALSE))
microbenchmark(
  c1 = mtbigger[, .(model = .(lm(mpg ~ cyl + disp, data = mtbigger[.I]))), by = .(cyl)],
  c2 = mtbigger[, .(model = .(lm(mpg ~ cyl + disp))), by =.(cylgroup=cyl)],
  c3 = mtbigger[, .(model = .(lm(mpg ~ cyl + disp, .SD))), by=cyl, .SDcols=names(mtbigger)],
  c4 = mtbigger[, .(model = .(lm(mpg ~ cyl + disp, .SD))), by=cyl, .SDcols=TRUE],
  c5 = mtbigger[, .(model = .(lm(mpg ~ cyl + disp, data = cbind(.SD, as.data.table(.BY))))), by = "cyl"]
)
# Unit: milliseconds
#  expr     min       lq     mean  median       uq      max neval
#    c1 27.1635 30.54040 33.98210 32.2859 34.71505  76.5064   100
#    c2 23.9612 25.83105 28.97927 27.5059 30.02720  67.9793   100
#    c3 25.7880 28.27205 31.38212 30.2445 32.79030 105.4742   100
#    c4 25.6469 27.84185 30.52403 29.8286 32.60805  37.8675   100
#    c5 29.2477 32.32465 35.67090 35.0291 37.90410  68.5017   100

(I'm guessing the relative performance scales similarly. A better adjudication might include much wider data.)

By median runtime alone, the fastest candidate (by a very small margin) is:

mtbigger[, .(model = .(lm(mpg ~ cyl + disp))), by =.(cylgroup=cyl)]
r2evans answered Nov 14 '22 13:11
