 

Select max of absolute value for each numeric column by a grouping column in data.table or dplyr

Here is an example of what my data.frame looks like:

opts <- seq(-0.5, 0.5, 0.05)
df <- data.frame(combo1=sample(opts, 6),
                 combo2=sample(opts, 6),
                 combo3=sample(opts, 6),
                 gene=rep(c("g1", "g2", "g3"), each=2), stringsAsFactors=F)

df
   combo1 combo2 combo3 gene
1   0.40   0.50  -0.10   g1
2   0.10  -0.20  -0.35   g1
3  -0.35  -0.35   0.40   g2
4   0.00   0.10  -0.30   g2
5  -0.45  -0.10   0.05   g3
6  -0.40  -0.40  -0.05   g3

For each combo, I want to group by gene and then select the value with the largest absolute value (keeping its sign). I can accomplish this using dplyr:

library(dplyr)
df_final <- data.frame(row.names=unique(df$gene))

for (combo in colnames(df)[1:3]) {

    combo_preds <- df[, c(combo, "gene")]
    colnames(combo_preds) <- c("pred", "gene")

    combo_preds %>%
        group_by(gene) %>%
        arrange(desc(abs(pred))) %>%
        slice(1) %>%
        ungroup() ->
        combo_preds

    # convert back to a plain data.frame, then add to df_final
    combo_preds <- as.data.frame(combo_preds)
    df_final[combo_preds$gene, combo] <- combo_preds$pred
}
#names rows based on gene
row.names(df_final) <- unique(df$gene)

df_final
    combo1 combo2 combo3
g1   0.40   0.50  -0.35
g2  -0.35  -0.35   0.40
g3  -0.45  -0.40   0.05

Is there a way to accomplish the above with data.table or some other more efficient implementation? In reality, I have ~1300 data frames, each with ~14,000 genes and ~650 combos. The current implementation takes about 2.6 minutes per data frame, so the whole job would take over 2 days.

asked by alexvpickering


2 Answers

You can certainly do it with data.table (I haven't benchmarked it against your version):

library(data.table)
dt <- data.table(df)
dt[, lapply(.SD, function (col) col[which.max(abs(col))]), by='gene']

It groups your table by gene and, within each chunk, uses lapply to loop over the columns and pick the value with the largest absolute value.
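
As a small optional variant (a sketch only; the helper name max_abs_keep_sign is purely illustrative, and it assumes gene is the only non-numeric column), here is the same logic with the selector factored out and .SD restricted explicitly:

# keep the signed value whose absolute value is largest
max_abs_keep_sign <- function(x) x[which.max(abs(x))]

# every column except the grouping one
combo_cols <- setdiff(names(dt), "gene")
dt[, lapply(.SD, max_abs_keep_sign), by = "gene", .SDcols = combo_cols]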

However, you might be better off reshaping your table to long format, though that depends on your specific data (you'll have to try it and see).

opts <- seq(-0.5, 0.5, 0.05)
n.combos <- 600
n.genes <- 10000
n.rows.per.gene <- 5

# columns are called X1 X2 instead of combo1 combo2 but no matter.
df.wide <- data.frame(replicate(n.combos, sample(opts, n.rows.per.gene, replace=T)),
                      gene=rep(paste0("g", 1:n.genes), each=n.rows.per.gene))

Here df.wide looks like your data frame, with one column per combo and one row per replicate per gene.

Here's the original data.table approach, timed on this larger data frame:

# data.table option
library(data.table)
dt <- data.table(df.wide)
system.time({
out <- dt[, lapply(.SD, function (col) col[which.max(abs(col))]), by='gene']
})
#    user  system elapsed 
#  10.757   0.364  12.612

Now we reshape to long format, which just has a 'combo' column and a 'value' column:

# reshape to long
dt.long <- melt(dt, id.vars='gene', variable.name='combo')
# > head(dt.long)
#    gene combo value
# 1:   g1    X1  0.20
# 2:   g1    X1  0.30
# 3:   g1    X1  0.10
# 4:   g1    X1  0.05
# 5:   g1    X1  0.30
# 6:   g2    X1  0.20

system.time({out.long <- dt.long[, value[which.max(abs(value))], by='gene,combo']})

   user  system elapsed 
  8.000   0.472   9.525 

This gives you a data.table with the gene, the combo and the value with the largest absolute value. You could reshape it back to wide if you wanted.
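
A sketch of that cast back to wide (the result column is called V1 by default, because the j expression above is unnamed):

# one row per gene, one column per combo again
out.wide <- dcast(out.long, gene ~ combo, value.var = "V1")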

So the long format doesn't seem much faster here; you'll have to try it on your data and see. Even if the second method is faster, you also have to factor in the time taken to melt the data.table to long format (which doesn't seem like much).
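
If you want to account for that reshape cost explicitly, you can time the melt step on its own (a minimal sketch, reusing dt from above):

# time just the wide-to-long reshape
system.time({
  dt.long <- melt(dt, id.vars = "gene", variable.name = "combo")
})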

Compare to, e.g., dplyr (which is very elegant to use, but slower):

system.time({
out.dplyr <- df.wide %>% group_by(gene) %>%
  summarise_each(funs(.[which.max(abs(.))]))
})
#   user  system elapsed 
# 163.106   7.989 189.788
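
Since you mention having ~1300 data frames, here's a sketch of running the data.table version over all of them (df_list and result_list are hypothetical names for your list of inputs and outputs):

max_abs_by_gene <- function(d) {
  # setDT converts in place; use as.data.table(d) instead if you need to keep the original untouched
  setDT(d)[, lapply(.SD, function(col) col[which.max(abs(col))]), by = "gene"]
}

# df_list is assumed to hold your ~1300 data frames
result_list <- lapply(df_list, max_abs_by_gene)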
answered by mathematical.coffee


Here's a much simpler and faster dplyr method:

df %>% group_by(gene) %>%
  summarise_each(funs(.[which.max(abs(.))]))

Try it on reproducible data:

set.seed(495)
opts <- seq(-0.5, 0.5, 0.05)
df <- data.frame(combo1=sample(opts, 6),
                 combo2=sample(opts, 6),
                 combo3=sample(opts, 6),
                 gene=rep(c("g1", "g2", "g3"), each=2), stringsAsFactors=F)

df
  combo1 combo2 combo3 gene
1  -0.15   0.50  -0.25   g1
2  -0.45  -0.50   0.15   g1
3  -0.25   0.10  -0.30   g2
4   0.35  -0.40  -0.15   g2
5  -0.05  -0.35  -0.40   g3
6   0.15  -0.05  -0.10   g3
df %>% group_by(gene) %>%
  summarise_each(funs(.[which.max(abs(.))]))
   gene combo1 combo2 combo3
1    g1  -0.45   0.50  -0.25
2    g2   0.35  -0.40  -0.30
3    g3   0.15  -0.35  -0.40

Note that in the above case there's a tie in the absolute value for combo2 within gene g1 (0.50 vs -0.50); which.max simply keeps the first one it encounters. If that matters, you'll need to decide how to break ties.
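
For instance, a sketch of one possible rule, keeping the largest signed value when the absolute values tie (so +0.50 wins over -0.50), still in the summarise_each/funs style used above:

# among values tied on abs(), keep the largest signed value
df %>% group_by(gene) %>%
  summarise_each(funs(.[order(-abs(.), -.)[1]]))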

Timings for my dplyr method and a slightly faster version of @mathematical.coffee's data.table method (using a much larger sample data frame):

set.seed(495)
opts <- seq(-0.5, 0.5, 0.05)
df <- data.frame(combo1=sample(opts, 9e4, replace=TRUE),
                 combo2=sample(opts, 9e4, replace=TRUE),
                 combo3=sample(opts, 9e4, replace=TRUE),
                 gene=rep(c("g1", "g2", "g3"), each=3e4), stringsAsFactors=F)

microbenchmark::microbenchmark(
  dplyr=setDF(df) %>% group_by(gene) %>%
    summarise_each(funs(.[which.max(abs(.))])),
  data.table={setDT(df)[, lapply(.SD, function (col) col[which.max(abs(col))]), by='gene']}
)
Unit: milliseconds
      expr       min        lq      mean    median        uq      max neval cld
     dplyr 10.013623 11.839132 14.156735 12.284574 12.675220 32.35739   100   b
data.table  4.434841  6.008701  6.947104  6.222775  6.415083 29.52652   100  a

So the data.table version runs in about half the time of the dplyr version.

UPDATE: To address @Arun's comment, here's a much larger sample data frame with more columns and more gene categories.

# Large sample of fake data
set.seed(194)
genes=apply(expand.grid(letters,letters), 1, paste, collapse="")
df = data.frame(replicate(50, rnorm(26*26*1e3)), gene=genes)
object.size(df)
# 273 MB

microbenchmark::microbenchmark(
  dplyr=setDF(df) %>% group_by(gene) %>%
    summarise_each(funs(.[which.max(abs(.))])),
  data.table={setDT(df)[, lapply(.SD, function (col) col[which.max(abs(col))]), by='gene']},
  times=10
)
Unit: milliseconds
      expr       min        lq      mean    median        uq       max neval cld
     dplyr 1240.1695 1299.0425 1375.8298 1318.5343 1385.5854 1748.8112    10   b
data.table  464.5597  493.8959  527.7097  519.3607  585.1482  603.3916    10  a

UPDATE 2: Same as above, but with an even larger number of groups (26^3 instead of 26^2). As @Arun discussed, the data.table speed advantage increases with the larger number of groups.

# Large sample of fake data
set.seed(194)
genes=apply(expand.grid(letters,letters,letters), 1, paste, collapse="")
df = data.frame(replicate(50, rnorm(26*26*26*50)), gene=genes)
object.size(df)
# 356 MB

microbenchmark::microbenchmark(
  dplyr=setDF(df) %>% group_by(gene) %>%
    summarise_each(funs(.[which.max(abs(.))])),
  data.table={setDT(df)[, lapply(.SD, function (col) col[which.max(abs(col))]), by='gene']},
  times=1
)
Unit: seconds
      expr       min        lq      mean    median        uq       max neval
     dplyr 27.567790 27.567790 27.567790 27.567790 27.567790 27.567790     1
data.table  2.765047  2.765047  2.765047  2.765047  2.765047  2.765047     1
answered by eipi10