 

Method to operate on each row of data.table without using apply function

Tags:

r

data.table

I wrote a simple function below:

mcs <- function(v) { ifelse(sum((diff(sort(v)) > 6) > 0), NA, sd(v)) }

It is supposed to take a vector, sort it, and check whether any successive difference in the sorted vector is greater than 6. It returns NA if there is such a difference and the standard deviation if there is not.
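The check just described can be sketched in standard C++. This is a minimal illustration, assuming no missing values and using NaN in place of R's NA. The name mcs_row is hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical mcs_row(): the mcs() check for one vector of values.
// NaN stands in for R's NA; missing input values are not handled.
double mcs_row(std::vector<double> v, double max_gap = 6.0) {
    const double na = std::numeric_limits<double>::quiet_NaN();
    // R's sd() returns NA for fewer than two values
    if (v.size() < 2) return na;
    // Sort the copy (v is taken by value) so successive gaps can be checked
    std::sort(v.begin(), v.end());
    for (std::size_t j = 1; j < v.size(); ++j) {
        if (v[j] - v[j - 1] > max_gap) return na;  // any gap > 6 -> NA
    }
    // Sample standard deviation with an n - 1 divisor, as in R's sd()
    double mean = 0.0;
    for (double x : v) mean += x;
    mean /= static_cast<double>(v.size());
    double ss = 0.0;
    for (double x : v) ss += (x - mean) * (x - mean);
    return std::sqrt(ss / static_cast<double>(v.size() - 1));
}
```

For example, columns B, C, D of row 3 in the example below are 4, 10, 3. Sorted, they have gaps 1 and 6, so neither exceeds 6 and the function returns about 3.7859. The values 2, 3, 10 have a gap of 7, so they produce NaN.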

I would like to apply this function to every row of a data.table (using only certain columns) and append each row's return value to the data.table as a new column.

For example, given this data.table:

> dat <- data.table(A=c(1,2,3,4,5), B=c(2,3,4,10,6), C=c(3,4,10,6,8), D=c(3,3,3,3,3))
> dat
   A  B  C D
1: 1  2  3 3
2: 2  3  4 3
3: 3  4 10 3
4: 4 10  6 3
5: 5  6  8 3

I would like to generate the output below (applying the function to columns 2, 3, and 4 of each row):

> dat
   A  B  C D        sd
1: 1  2  3 3 0.5773503
2: 2  3  4 3 0.5773503
3: 3  4 10 3 3.7859389
4: 4 10  6 3 3.5118846
5: 5  6  8 3 2.5166115

I learned that row-wise operations can be done on a data.table like this:

> dat[, sd:=apply(.SD, 1, mcs), .SDcols=(c(2,3,4))]

This works, except that it is too slow. I have to perform this operation on several large data tables, and I wrote a script to do so. However, it only works for the smaller tables. Tables with ~300,000 rows finish in a few seconds, but a table with ~800 million rows never finishes. I've waited up to two hours, and I think R breaks somehow, because the console just freezes. I've run the script several times: it always processes the first few smaller tables correctly (I had the program write each table to a file to check), but when it reaches the large table it never finishes. I am running this on a computing cluster, so I really don't think this is a hardware limitation. It's probably poor code.

I assume the bottleneck is the looping done in apply, but I don't know how to make it faster. I am pretty new to R, so I am not sure how to optimize my code. I've seen a lot of posts online about vectorizing, and I think that if I could apply my function to every row simultaneously it might be much faster, but I don't know how to do that. Please help.
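For the three-column case here, "vectorizing" could mean replacing the per-row sort with whole-column operations. Sorted, three numbers are simply their minimum, their middle value (the sum minus the minimum minus the maximum), and their maximum. In R, the minimum and maximum steps correspond to pmin() and pmax() applied to entire columns. The sketch below writes these column-wise steps in C++ for illustration. It assumes exactly three equal-length columns and no missing values. The name mcs3_columnwise is hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical column-wise version for exactly three columns (B, C, D).
// Each pass runs element-wise over whole columns, mirroring what pmin(),
// pmax() and arithmetic operators do in R, with no per-row sort.
std::vector<double> mcs3_columnwise(const std::vector<double>& b,
                                    const std::vector<double>& c,
                                    const std::vector<double>& d,
                                    double max_gap = 6.0) {
    const std::size_t n = b.size();  // assumes c and d also have length n
    std::vector<double> lo(n), hi(n), mid(n), out(n);
    // Sorted order of three values: minimum, middle, maximum
    for (std::size_t i = 0; i < n; ++i) lo[i] = std::min({b[i], c[i], d[i]});
    for (std::size_t i = 0; i < n; ++i) hi[i] = std::max({b[i], c[i], d[i]});
    for (std::size_t i = 0; i < n; ++i) mid[i] = b[i] + c[i] + d[i] - lo[i] - hi[i];
    for (std::size_t i = 0; i < n; ++i) {
        if (mid[i] - lo[i] > max_gap || hi[i] - mid[i] > max_gap) {
            out[i] = std::numeric_limits<double>::quiet_NaN();  // NA case
        } else {
            // Sample standard deviation of three values (divisor n - 1 = 2)
            const double mean = (b[i] + c[i] + d[i]) / 3.0;
            const double ss = (b[i] - mean) * (b[i] - mean)
                            + (c[i] - mean) * (c[i] - mean)
                            + (d[i] - mean) * (d[i] - mean);
            out[i] = std::sqrt(ss / 2.0);
        }
    }
    return out;
}
```

Computing the middle value by subtraction can introduce tiny rounding differences compared with sorting. The approach also only works for a fixed, small number of columns.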

Edit
Sorry, I made a mistake in copying my mcs function. I have updated it.

Edit 2
For those interested, I ended up splitting the table in half and operating on each half separately, and that worked for me.

CuriousTim asked Dec 10 '17 08:12




1 Answer

If you really need speed, it's usually best to move the loop to C++ using Rcpp. Here that gives a solution that's over 100x faster.

Data

I made some different example data to test this on, with 1000 rows instead of 5:

library(data.table)
set.seed(123)
dat <- data.table(A = rnorm(1e3, sd=4), B = rnorm(1e3, sd=4), C = rnorm(1e3, sd=4),
                  D = rnorm(1e3, sd=4), E = rnorm(1e3, sd=4))

Solution

I used the following C++ code to do the same thing as your function. The looping is now done in C++ rather than in R through apply, which saves considerable time:

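#include <Rcpp.h>
#include <algorithm>

using namespace Rcpp;

// Row-wise version of mcs(): for each row, sort the values and return NA if
// any successive difference exceeds 6, otherwise the row's standard deviation.
// [[Rcpp::export]]
NumericVector mcs2(DataFrame x) {
    int n = x.nrows();  // number of rows
    int m = x.size();   // number of (selected) columns
    // Copy the columns into a numeric matrix so that rows can be extracted
    NumericMatrix mat(n, m);
    for ( int j = 0; j < m; ++j ) {
        mat(_, j) = NumericVector(x[j]);
    }
    NumericVector result(n);
    for ( int i = 0; i < n; ++i ) {
        // Copy row i, then sort the copy (mat itself is left unchanged)
        NumericVector tmp = mat(i, _);
        std::sort(tmp.begin(), tmp.end());
        // Stop at the first successive gap larger than 6 and record NA
        bool do_sd = true;
        for ( int j = 1; j < m; ++j ) {
            if ( tmp[j] - tmp[j-1] > 6.0 ) {
                result[i] = NA_REAL;
                do_sd = false;
                break;
            }
        }
        // No large gap: use Rcpp sugar's sd() (sample standard deviation)
        if ( do_sd ) {
            result[i] = sd(tmp);
        }
    }
    return result;
}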
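The code above first copies every selected column into a NumericMatrix. For the ~800 million-row table in the question, and assuming three double columns, that copy alone is roughly 8×10^8 × 3 × 8 bytes ≈ 19.2 GB, on top of the table itself. One possible variant instead gathers each row directly from the columns into a small reusable buffer. It is sketched below in plain standard C++ so it compiles without R. The type Columns (a vector of equal-length columns standing in for the data frame) and the name mcs_by_row are hypothetical. In an Rcpp version, each column could be taken from the DataFrame as a NumericVector, which for double columns refers to R's existing data rather than copying it.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical stand-in for a data frame: a list of equal-length columns.
using Columns = std::vector<std::vector<double>>;

std::vector<double> mcs_by_row(const Columns& cols, double max_gap = 6.0) {
    const std::size_t m = cols.size();
    const std::size_t n = m ? cols[0].size() : 0;  // assumes equal lengths
    std::vector<double> result(n);
    std::vector<double> row(m);  // one reusable buffer, not an n x m copy
    for (std::size_t i = 0; i < n; ++i) {
        // Gather row i straight from the columns, then sort the buffer
        for (std::size_t j = 0; j < m; ++j) row[j] = cols[j][i];
        std::sort(row.begin(), row.end());
        bool gap = false;
        for (std::size_t j = 1; j < m; ++j) {
            if (row[j] - row[j - 1] > max_gap) { gap = true; break; }
        }
        // NA (NaN) for a large gap, or for a single column (sd undefined)
        if (gap || m < 2) {
            result[i] = std::numeric_limits<double>::quiet_NaN();
            continue;
        }
        // Sample standard deviation with an m - 1 divisor
        double mean = 0.0;
        for (double x : row) mean += x;
        mean /= static_cast<double>(m);
        double ss = 0.0;
        for (double x : row) ss += (x - mean) * (x - mean);
        result[i] = std::sqrt(ss / static_cast<double>(m - 1));
    }
    return result;
}
```

The per-row logic is the same as in mcs2(). Only the source of each row's values changes, so the extra memory is one row's worth instead of a full matrix. The output vector of n doubles is still needed either way.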

We can make sure it's returning the same values:

mcs1 <- mcs  # mcs1 is the mcs() function from the question
all.equal(apply(dat[, 2:4], 1, mcs1), mcs2(dat[, 2:4]))

[1] TRUE

Now let's benchmark:

library(rbenchmark)  # provides benchmark()
benchmark(mcs1 = dat[, sd := apply(.SD, 1, mcs1), .SDcols = c(2, 3, 4)],
          mcs2 = dat[, sd := mcs2(.SD), .SDcols = c(2, 3, 4)],
          order = 'relative',
          columns = c('test', 'elapsed', 'relative', 'user.self'))


  test elapsed relative user.self
2 mcs2    0.19    1.000     0.183
1 mcs1   21.34  112.316    20.044
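Here is a rough extrapolation to the ~800 million-row table from the question. The benchmark call above does not set replications, so this assumes rbenchmark's default of 100 replications on the 1000-row data. It also assumes run time scales linearly with the number of rows. On those assumptions, the elapsed times give approximately:

```latex
\begin{aligned}
\text{apply: } & \frac{21.34\ \text{s}}{100 \times 1000\ \text{rows}} \approx 2.13 \times 10^{-4}\ \text{s/row},
  \quad 8 \times 10^{8}\ \text{rows} \times 2.13 \times 10^{-4}\ \text{s/row} \approx 1.7 \times 10^{5}\ \text{s} \approx 47\ \text{h} \\
\text{Rcpp: } & \frac{0.19\ \text{s}}{100 \times 1000\ \text{rows}} = 1.9 \times 10^{-6}\ \text{s/row},
  \quad 8 \times 10^{8}\ \text{rows} \times 1.9 \times 10^{-6}\ \text{s/row} \approx 1.5 \times 10^{3}\ \text{s} \approx 25\ \text{min}
\end{aligned}
```

If these assumptions hold, the apply() version would need on the order of two days, which would be consistent with it not finishing within two hours. The Rcpp version would take under half an hour. Actual times depend on the machine and on memory pressure.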

How to compile this code

As an introduction to using C++ code through Rcpp, I'd suggest the Rcpp chapter of Hadley Wickham's Advanced R. If you intend to do anything further with Rcpp, I'd strongly recommend also reading the official documentation and vignettes, but Wickham's book is probably a more beginner-friendly starting point. For your purposes, you just need to get Rcpp up and running so that you can compile the code above.

For this code to work for you, you'll need the Rcpp package if you don't already have it. You can get the package by running

install.packages("Rcpp")

from R. Note that you'll also need a compiler. If you're on a Debian-based Linux system such as Ubuntu, you can run

sudo apt install r-base-dev

from the terminal. If you are on Mac or Windows, see the Rcpp documentation or the Wickham chapter mentioned above for instructions on getting a compiler set up (Rtools on Windows, the Xcode command line tools on macOS).

Once you have Rcpp installed, save the C++ code above into a file. Let's say for our example the file is named "SOanswer.cpp". Then you can make its mcs2() function available from R by putting the following two lines in your R script:

library(Rcpp)
sourceCpp("SOanswer.cpp") # assuming the file is in your working directory

That's it! Now your R script can call mcs2() and run much faster. If you want to learn more about Rcpp, besides the Wickham chapter above, I'd check out the Rcpp reference manual and vignettes on CRAN, RStudio's Rcpp resources, and the Rcpp Gallery, which has lots of really useful examples.

duckmayr answered Oct 11 '22 08:10