Applying a custom function on data.table instead of using plyr and ddply

Tags: r, data.table, plyr

I'm processing a data.table named orderFlow and computing potentialWelfare.tmp as the output. So far, the following plyr-based approach has been my solution, but because orderFlow has millions of rows, I would prefer a solution that leverages the performance of data.table in R.

    # solution so far, poor performance on huge orderFlow input data.table
    require(plyr)
    potentialWelfare.tmp = ddply(orderFlow, 
                       .variables = c("simulationrun_id", "db"), 
                       .fun = calcPotentialWelfare, 
                       .progress = "text", 
                       .parallel=TRUE)

Edit 1: Briefly, the custom function counts the asks in df (NbAsks), sorts the bids by valuation, and sums the valuations of the NbAsks highest-valued bids (or of all bids, if there are fewer bids than asks). This selects the most valuable bids and sums up their valuation. The code is legacy and probably not very efficient, but it served its purpose in combination with plyr and ordinary data.frames.

    calcPotentialWelfare <- function(df){
      # number of asks in this group
      NbAsks = nrow(df[df$type == "ask", ])
      # bids in this group, sorted by valuation (highest first)
      Bids = df[df$type == "bid", ]
      Bids = Bids[order(Bids$valuation, decreasing = TRUE), ]
      NbBids = nrow(Bids)
      if (NbAsks > 0){
        # keep the min(NbAsks, NbBids) most valuable bids; seq_len() gives an
        # empty selection (and therefore a sum of 0) when there are no bids
        Bids = Bids[seq_len(min(NbAsks, NbBids)), ]
        potentialWelfare = sum(Bids$valuation)
        return(potentialWelfare)
      }
      else{
        potentialWelfare = 0
        return(potentialWelfare)
      }
    }
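For reference, the rule described in Edit 1 can be restated on two plain vectors, which makes it easy to check results against. This is a minimal sketch; the helper name `potentialWelfareVec` is hypothetical and not part of the original code. It returns 0 both when a group has no asks and when it has asks but no bids.

```r
# Hypothetical helper: same rule as calcPotentialWelfare, applied to two vectors
potentialWelfareVec <- function(type, valuation) {
  nAsks <- sum(type == "ask")
  # bids sorted from highest to lowest valuation
  bids <- sort(valuation[type == "bid"], decreasing = TRUE)
  # head() returns at most length(bids) items, and head(x, 0) is empty,
  # so the sum is 0 when there are no asks
  sum(head(bids, nAsks))
}

# The six sample rows from Edit 2: one ask and five bids
potentialWelfareVec(
  type = c("ask", "bid", "bid", "bid", "bid", "bid"),
  valuation = c(0.3, 0.39687632952068, 0.968033835246625, 0.0616318564942726,
                0.570171430446081, 0.371884415116724)
)
```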

Unfortunately, I cannot find a working way to implement this using data.table. What I have arrived at so far, using ?data.table and the corresponding FAQ, is this:

    #   trying to use data.table, but it doesn't work so far.
    potentialWelfare.tmp = orderFlow[, lapply(.SD, calcPotentialWelfare), by = list(simulationrun_id, db),.SDcols=c("simulationrun_id", "db")]

What I get is:

    Error in `[.data.frame`(orderFlow, , lapply(.SD, calcPotentialWelfare),  : unused arguments (by = list(simulationrun_id, db), .SDcols = c("simulationrun_id", "db"))
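The error is raised by `[.data.frame`, which suggests the object being indexed is a plain data.frame (the dput output in Edit 2 indeed shows `class = "data.frame"`). A data.frame's `[` method has no `by` or `.SDcols` arguments, so they are reported as unused. A minimal sketch reproducing this on a hypothetical toy data.frame and fixing it with `setDT()`, which converts to a data.table by reference:

```r
library(data.table)

# Hypothetical toy data, standing in for orderFlow
df <- data.frame(g = c(1, 1, 2), v = c(1, 2, 3))

# On a data.frame, `[` dispatches to `[.data.frame`, which rejects by=
msg <- tryCatch(df[, sum(v), by = g], error = function(e) conditionMessage(e))
print(msg)

# Convert in place; now data.table's `[` (with by=, .SDcols=, ...) is used
setDT(df)
res <- df[, .(total = sum(v)), by = g]
print(res)
```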

Here is the input:

    > head(orderFlow)
      type  valuation price               dateCreation                    dateDue                dateMatched id
    1  ask 0.30000000   0.3 2012-01-01 00:00:00.000000 2012-01-01 00:30:00.000000 2012-01-01 00:01:01.098307  1
    2  bid 0.39687633   0.0 2012-01-01 00:01:01.098307 2012-01-01 00:10:40.024807 2012-01-01 00:01:01.098307  2
    3  bid 0.96803384    NA 2012-01-01 00:03:05.660811 2012-01-01 00:06:26.368941                       <NA>  3
    4  bid 0.06163186    NA 2012-01-01 00:05:25.413959 2012-01-01 00:09:06.189893                       <NA>  4
    5  bid 0.57017143    NA 2012-01-01 00:10:10.344876 2012-01-01 00:57:58.998516                       <NA>  5
    6  bid 0.37188442    NA 2012-01-01 00:11:25.761372 2012-01-01 00:43:24.274176                       <NA>  6
              created_at updated_at simulationrun_id db
    1 2013-12-10 14:37:29.065634         NA             7004  1
    2 2013-12-10 14:37:29.065674         NA             7004  1
    3 2013-12-10 14:37:29.065701         NA             7004  1
    4 2013-12-10 14:37:29.065726         NA             7004  1
    5 2013-12-10 14:37:29.065750         NA             7004  1
    6 2013-12-10 14:37:29.065775         NA             7004  1

I am expecting something like the following as output: the function calcPotentialWelfare aggregates the column 'valuation' of the data.table orderFlow in a somewhat special way, producing one value per (simulationrun_id, db) pair.

    > head(potentialWelfare.tmp)
      simulationrun_id db potentialWelfare
    1                1  1         16.86684
    2                2  1         18.44314
    3                4  1         16.86684
    4                5  1         18.44314
    5                7  1         16.86684
    6                8  1         18.44314

Really excited to see this get solved. Thanks for reading!

Edit 2:

    > dput(head(orderFlow))
    structure(list(type = c("ask", "bid", "bid", "bid", "bid", "bid"
    ), valuation = c(0.3, 0.39687632952068, 0.968033835246625, 0.0616318564942726, 
    0.570171430446081, 0.371884415116724), price = c(0.3, 0, NA, 
    NA, NA, NA), dateCreation = c("2012-01-01 00:00:00.000000", "2012-01-01 00:01:01.098307", 
    "2012-01-01 00:03:05.660811", "2012-01-01 00:05:25.413959", "2012-01-01 00:10:10.344876", 
    "2012-01-01 00:11:25.761372"), dateDue = c("2012-01-01 00:30:00.000000", 
    "2012-01-01 00:10:40.024807", "2012-01-01 00:06:26.368941", "2012-01-01 00:09:06.189893", 
    "2012-01-01 00:57:58.998516", "2012-01-01 00:43:24.274176"), 
        dateMatched = c("2012-01-01 00:01:01.098307", "2012-01-01 00:01:01.098307", 
        NA, NA, NA, NA), id = 1:6, created_at = c("2013-12-10 14:37:29.065634", 
        "2013-12-10 14:37:29.065674", "2013-12-10 14:37:29.065701", 
        "2013-12-10 14:37:29.065726", "2013-12-10 14:37:29.065750", 
        "2013-12-10 14:37:29.065775"), updated_at = c(NA_real_, NA_real_, 
        NA_real_, NA_real_, NA_real_, NA_real_), simulationrun_id = c(7004L, 
        7004L, 7004L, 7004L, 7004L, 7004L), db = c(1L, 1L, 1L, 1L, 
        1L, 1L)), .Names = c("type", "valuation", "price", "dateCreation", 
    "dateDue", "dateMatched", "id", "created_at", "updated_at", "simulationrun_id", 
    "db"), row.names = c(NA, 6L), class = "data.frame")
Peter Lustig asked Dec 16 '13 at 21:12


1 Answer

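I think this should be faster. There are some errors in the way you use data.table. For one, the error message comes from `[.data.frame`, which means orderFlow is still a data.frame (as your dput output confirms), so by= and .SDcols= are rejected as unused arguments. I suggest you read through the introduction, go through the examples, and read the FAQ.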

    library(data.table)

    calcPotentialWelfare <- function(dt){
      NbAsks = nrow(dt["ask", nomatch=0L]) # binary-search-based subset/join: very fast
      Bids   = dt["bid", nomatch=0L]       # binary-search-based subset/join: very fast
      NbBids = nrow(Bids)
      # Within each 'type', 'valuation' is already sorted (see setkey below),
      # but in ascending order, while you need the highest values.
      # So you can just use 'tail' to fetch the last 'n' bids, as follows.
      if (NbAsks > 0) return(sum(tail(Bids, min(NbAsks, NbBids))$valuation))
      else return(0)
    }

    # orderFlow is a data.frame (see the dput output); convert it by reference first
    setDT(orderFlow)
    # setkey on the 'type' column to use binary-search-based subset/join in the function,
    # and also on 'valuation' so that we don't have to 'order' within every group
    # inside the function: we can use 'tail' instead
    setkey(orderFlow, type, valuation)
    potentialWelfare.tmp =
      orderFlow[, list(potentialWelfare = calcPotentialWelfare(.SD)), # name the result column
                by=.(simulationrun_id, db),
                .SDcols=c("type", "valuation")]
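The keyed subsets `dt["ask"]` and `dt["bid"]` above rely on .SD being keyed on 'type'. If you would rather not depend on that, the same per-group quantity can also be computed inline in j with plain vector operations. This is a swapped-in variant, not the keyed approach above: it sorts the bids within each group instead of relying on setkey, and it has not been benchmarked against the keyed version here. The data uses the six sample rows from Edit 2 (only the needed columns) plus two made-up, bid-only rows for a hypothetical simulationrun_id 7005.

```r
library(data.table)

# Sample rows from Edit 2, plus two hypothetical bid-only rows for run 7005
orderFlow <- data.table(
  type = c("ask", "bid", "bid", "bid", "bid", "bid", "bid", "bid"),
  valuation = c(0.3, 0.39687632952068, 0.968033835246625, 0.0616318564942726,
                0.570171430446081, 0.371884415116724, 0.5, 0.2),
  simulationrun_id = c(rep(7004L, 6), 7005L, 7005L),
  db = 1L
)

potentialWelfare.tmp <- orderFlow[, {
  nAsks <- sum(type == "ask")                                # asks in this group
  bids  <- sort(valuation[type == "bid"], decreasing = TRUE) # bids, best first
  list(potentialWelfare = sum(head(bids, nAsks)))            # top nAsks bids (0 if no asks)
}, by = .(simulationrun_id, db)]

print(potentialWelfare.tmp)
```

Run 7004 has one ask, so only its best bid (0.968...) counts; run 7005 has no asks, so its value is 0.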

.SD is a special symbol: for each group, it is a data.table holding that group's rows of all the columns not mentioned in by= (if .SDcols is not specified). If .SDcols is specified, then .SD holds just those columns, again restricted to the rows of the current group.
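A small illustration of what .SD contains with and without .SDcols; the table `DT` and its columns are hypothetical toy data:

```r
library(data.table)

DT <- data.table(g = c(1L, 1L, 2L),
                 type = c("ask", "bid", "bid"),
                 valuation = c(0.3, 0.4, 0.9),
                 id = 1:3)

# No .SDcols: .SD holds every column not used in by=, for one group's rows
allCols <- DT[, .(cols = paste(names(.SD), collapse = ","), rows = nrow(.SD)), by = g]

# With .SDcols: .SD holds only the listed columns, still one group's rows at a time
someCols <- DT[, .(cols = paste(names(.SD), collapse = ","), rows = nrow(.SD)),
               by = g, .SDcols = c("type", "valuation")]

print(allCols)
print(someCols)
```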

Using lapply(.SD, ...) passes each column to the function separately, which is not what you need: you need to pass the whole group's data to the function at once. And since you only need the columns 'type' and 'valuation' inside your function, you can speed it up by specifying .SDcols=c('type', 'valuation'). This saves a lot of time, because the other columns are never copied into .SD.
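The difference between the two call styles can be seen on a hypothetical toy table: with `lapply(.SD, f)` the function receives one plain column vector at a time, while with `f(.SD)` it receives the whole sub-table for the group, so it can look at 'type' and 'valuation' together, as calcPotentialWelfare needs to.

```r
library(data.table)

DT <- data.table(g = c(1L, 1L, 2L),
                 type = c("ask", "bid", "bid"),
                 valuation = c(0.3, 0.4, 0.9))

# lapply(.SD, f): f is called once per column, receiving a vector
perColumn <- DT[, lapply(.SD, class), by = g, .SDcols = c("type", "valuation")]

# f(.SD): f is called once per group, receiving the whole sub-table
wholeGroup <- DT[, .(cls = class(.SD)[1], nBids = sum(.SD$type == "bid")),
                 by = g, .SDcols = c("type", "valuation")]

print(perColumn)
print(wholeGroup)
```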

Arun answered Oct 19 '22 at 23:10