
Why is standard R median function so much slower than a simple C++ alternative?

I wrote the following implementation of the median in C++ and used it in R via Rcpp:

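#include <Rcpp.h>
#include <algorithm>
#include <vector>

// [[Rcpp::export]]
double median2(std::vector<double> x){
  // x is passed by value, so sorting does not modify the caller's vector
  double median;
  size_t size = x.size();
  std::sort(x.begin(), x.end());
  if (size  % 2 == 0){
      // even length: average the two middle values
      median = (x[size / 2 - 1] + x[size / 2]) / 2.0;
  }
  else {
      // odd length: take the middle value
      median = x[size / 2];
  }
  return median;
}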

If I then compare its performance with the built-in R median function, I get the following results from microbenchmark:

> x = rnorm(100)
> microbenchmark(median(x),median2(x))
Unit: microseconds
       expr    min     lq     mean median     uq     max neval
  median(x) 25.469 26.990 34.96888 28.130 29.081 518.126   100
 median2(x)  1.140  1.521  2.47486  1.901  2.281  47.897   100

Why is the standard median function so much slower? This isn't what I would expect...

Ruben asked Jan 13 '16 15:01


1 Answer

As noted by @joran, your code is highly specialized, and less general functions and algorithms are often faster. Take a look at median.default:

median.default
# function (x, na.rm = FALSE) 
# {
#   if (is.factor(x) || is.data.frame(x)) 
#     stop("need numeric data")
#   if (length(names(x))) 
#     names(x) <- NULL
#   if (na.rm) 
#     x <- x[!is.na(x)]
#   else if (any(is.na(x))) 
#     return(x[FALSE][NA])
#   n <- length(x)
#   if (n == 0L) 
#     return(x[FALSE][NA])
#   half <- (n + 1L)%/%2L
#   if (n%%2L == 1L) 
#     sort(x, partial = half)[half]
#   else mean(sort(x, partial = half + 0L:1L)[half + 0L:1L])
# }

Several steps exist to handle missing values (the is.na() checks and the na.rm branch), and these add to the function's execution time. Because your function skips them, it avoids that work, but it will not return the same result for vectors with missing values:

median(c(1, 2, NA))
#[1] NA

median2(c(1, 2, NA))
#[1] 2
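To make the cost of the missing-value check concrete, here is a minimal sketch in plain C++ (no Rcpp, so it runs on its own). The name median_na_aware is hypothetical and not part of the original answer. It relies on the fact that R's NA_real_ is a NaN at the IEEE level, so std::isnan is true for it; the sketch returns a plain NaN, whereas an Rcpp version would presumably return NA_REAL to keep the NA/NaN distinction.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper, not from the original answer: a median that returns
// NaN when any input is NaN, roughly mirroring median(x, na.rm = FALSE).
// x is taken by value, so the caller's data is left unsorted.
double median_na_aware(std::vector<double> x) {
  const double nan = std::numeric_limits<double>::quiet_NaN();
  if (x.empty()) return nan;  // median.default also returns NA for length 0

  // Extra O(n) pass over the data: this is the price of the NA check.
  for (double v : x) {
    if (std::isnan(v)) return nan;
  }

  const std::size_t size = x.size();
  std::sort(x.begin(), x.end());
  if (size % 2 == 0) return (x[size / 2 - 1] + x[size / 2]) / 2.0;
  return x[size / 2];
}
```

The scan is linear, so for a vector of 100 elements it is small next to R-level overhead such as dispatch, but it is work that median2 never does.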

A couple of other factors which probably don't have as much of an effect as the handling of NAs, but are worth pointing out:

  • median and several of the functions it calls (such as sort and mean) are S3 generics, so a small amount of time is spent on method dispatch
  • median works with more than just integer and numeric vectors; it also handles Date, POSIXt, and probably a number of other classes, and preserves attributes correctly:

median(Sys.Date() + 0:4)
#[1] "2016-01-15"

median(Sys.time() + (0:4) * 3600 * 24)
#[1] "2016-01-15 11:14:31 EST"

Edit: Note that the cpp_med function below sorts the original vector in place, because a NumericVector is a proxy for the underlying R object rather than a copy of it. To avoid this, you can either Rcpp::clone the input vector and operate on the clone, or keep your original signature (with a std::vector<double>), which implicitly makes a copy in the conversion from SEXP to std::vector.

Also note that you can shave off a little more time by using a NumericVector instead of a std::vector<double>:

#include <Rcpp.h>
#include <algorithm>

// [[Rcpp::export]]
double cpp_med(Rcpp::NumericVector x){
  std::size_t size = x.size();
  // sorts in place: the caller's R vector is modified as a side effect
  std::sort(x.begin(), x.end());
  if (size  % 2 == 0) return (x[size / 2 - 1] + x[size / 2]) / 2.0;
  return x[size / 2];
}

microbenchmark::microbenchmark(
  median(x),
  median2(x),
  cpp_med(x),
  times = 200L
)
# Unit: microseconds
#       expr    min      lq      mean  median      uq     max neval
#  median(x) 74.787 81.6485 110.09870 92.5665 129.757 293.810   200
# median2(x)  6.474  7.9665  13.90126 11.0570  14.844 151.817   200
# cpp_med(x)  5.737  7.4285  11.25318  9.0270  13.405  52.184   200
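The in-place side effect described in the edit above can be illustrated with a plain C++ analogy. This is only a sketch under an assumption: a reference parameter stands in for Rcpp::NumericVector (which shares memory with the R object), and a by-value parameter stands in for the std::vector<double> signature or an explicit Rcpp::clone. The function names are hypothetical.

```cpp
#include <algorithm>
#include <vector>

// Analogy for the NumericVector case: the function works on the caller's
// storage, so the caller sees the sorted data afterwards.
void sort_shared(std::vector<double>& x) {
  std::sort(x.begin(), x.end());
}

// Analogy for the std::vector<double> signature (or Rcpp::clone): the
// function works on its own copy and the caller's data is untouched.
std::vector<double> sort_copy(std::vector<double> x) {
  std::sort(x.begin(), x.end());
  return x;
}
```

The copy costs an extra allocation and a pass over the data, which is likely why the NumericVector version is slightly faster in the benchmark above.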

Yakk brought up a great point in the comments above - also elaborated on by Jerry Coffin - about the inefficiency of doing a complete sort. Here's a rewrite using std::nth_element, benchmarked on a much larger vector:

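#include <Rcpp.h>
#include <algorithm>

// [[Rcpp::export]]
double cpp_med2(Rcpp::NumericVector xx) {
  // work on a copy so the caller's vector is not reordered
  Rcpp::NumericVector x = Rcpp::clone(xx);
  std::size_t n = x.size() / 2;
  // partial ordering: x[n] ends up where a full sort would put it
  std::nth_element(x.begin(), x.begin() + n, x.end());

  if (x.size() % 2) return x[n];
  // even length: the lower middle value is the largest element before n
  return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.;
}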

set.seed(123)
xx <- rnorm(10e5)

all.equal(cpp_med2(xx), median(xx))
all.equal(median2(xx), median(xx))

microbenchmark::microbenchmark(
  cpp_med2(xx), median2(xx), 
  median(xx), times = 200L
)
# Unit: milliseconds
#         expr      min       lq     mean   median       uq       max neval
# cpp_med2(xx) 10.89060 11.34894 13.15313 12.72861 13.56161  33.92103   200
#  median2(xx) 84.29518 85.47184 88.57361 86.05363 87.70065 228.07301   200
#   median(xx) 46.18976 48.36627 58.77436 49.31659 53.46830 250.66939   200
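For readers who want to check the selection approach outside of R, here is a minimal standalone sketch in plain C++ that compares a full sort against std::nth_element. The names median_sort and median_select are hypothetical, and both functions assume a non-empty input without NaN values.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Reference version: full sort, O(n log n). Assumes x is non-empty.
double median_sort(std::vector<double> x) {
  std::sort(x.begin(), x.end());
  const std::size_t n = x.size() / 2;
  if (x.size() % 2) return x[n];
  return (x[n - 1] + x[n]) / 2.0;
}

// Selection version: std::nth_element, O(n) on average. Assumes x is non-empty.
double median_select(std::vector<double> x) {
  const std::size_t n = x.size() / 2;
  std::nth_element(x.begin(), x.begin() + n, x.end());
  if (x.size() % 2) return x[n];
  // After nth_element, every element before position n is <= x[n], so the
  // lower middle value is the maximum of that prefix.
  const double lower = *std::max_element(x.begin(), x.begin() + n);
  return (lower + x[n]) / 2.0;
}
```

Both functions take their argument by value, so neither reorders the caller's data. The selection version only guarantees the position of one element, which is all a median needs.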
nrussell answered Nov 15 '22 03:11
