Find the largest n unique values and their frequencies in R and Rcpp

I have a numeric vector v (with NAs already omitted) and want to get the n largest unique values and their respective frequencies.

I found the priority-queue approach from http://gallery.rcpp.org/articles/top-elements-from-vectors-using-priority-queue/ to be quite fast:

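#include <Rcpp.h>
#include <functional>
#include <queue>
#include <utility>
#include <vector>
using namespace Rcpp;

// [[Rcpp::export]]
std::vector<int> top_i_pq(NumericVector v, unsigned int n)
{
  // Min-heap of (value, index): top() is the smallest of the n kept so far
  typedef std::pair<double, int> Elt;
  std::priority_queue< Elt, std::vector<Elt>, std::greater<Elt> > pq;
  std::vector<int> result;

  for (int i = 0; i != v.size(); ++i) {
    Elt elt(v[i], i);
    if (pq.size() < n) {
      pq.push(elt);
    } else if (n > 0 && pq.top() < elt) {  // beats the current minimum
      pq.pop();
      pq.push(elt);
    }
  }

  // Convert to 1-based R indices, smallest value first
  result.reserve(pq.size());
  while (!pq.empty()) {
    result.push_back(pq.top().second + 1);
    pq.pop();
  }

  return result;
}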

However, ties are not respected: each copy of a repeated value takes its own slot in the queue. I don't actually need the indices; returning the values would also be OK.

What I would like to get is a list containing the values and their frequencies, something like:

numv <- c(4.2, 4.2, 4.5, 0.1, 4.4, 2.0, 0.9, 4.4, 3.3, 2.4, 0.1)

top_i_pq(numv, 3)
$lengths
[1] 2 2 1

$values
[1] 4.2 4.4 4.5
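For comparison, here is a plain C++ sketch of the gallery loop (std::vector instead of NumericVector; the name `top_values_pq` is hypothetical) that returns values instead of indices. For `numv` with n = 3 it gives 4.4, 4.4, 4.5: each copy of 4.4 takes its own slot, so 4.2 is pushed out.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// Plain-C++ port of the gallery loop (hypothetical name top_values_pq).
// Keeps the n largest (value, index) pairs in a min-heap, so equal values
// are distinct elements and each occupies its own slot.
std::vector<double> top_values_pq(const std::vector<double>& v, std::size_t n) {
    typedef std::pair<double, std::size_t> Elt;
    std::priority_queue<Elt, std::vector<Elt>, std::greater<Elt> > pq;

    for (std::size_t i = 0; i != v.size(); ++i) {
        Elt elt(v[i], i);
        if (pq.size() < n) {
            pq.push(elt);
        } else if (n > 0 && pq.top() < elt) {
            pq.pop();        // drop the current smallest
            pq.push(elt);
        }
    }

    // Pops come out smallest first, so the result is in ascending order.
    std::vector<double> result;
    result.reserve(pq.size());
    while (!pq.empty()) {
        result.push_back(pq.top().first);
        pq.pop();
    }
    return result;
}
```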

Computing a unique vector, a table, or a (full) sort are all poor options here, because n is usually small compared to the length of v (which can easily exceed 1e6).
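One way to stay at O(length(v) · log n) without a full sort is to swap the priority queue for an ordered map holding at most n distinct values and their counts. This is a sketch in plain C++ (not Rcpp); `top_n_map` and `TopN` are hypothetical names, and it assumes v contains no NaN, which holds here because the NAs are already omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Mirrors the desired R list: $values (ascending) and $lengths (counts).
struct TopN {
    std::vector<double> values;
    std::vector<int> lengths;
};

// Sketch (hypothetical name top_n_map): keep at most n distinct values in an
// ordered map from value to count; the smallest kept value is m.begin().
// Assumes no NaN in v, since NaN breaks the map's ordering.
TopN top_n_map(const std::vector<double>& v, std::size_t n) {
    TopN out;
    if (n == 0) return out;  // begin() of an empty map must not be dereferenced

    std::map<double, int> m;
    for (double x : v) {
        // Map is full and x is below every kept value: discard it.
        if (m.size() == n && x < m.begin()->first) continue;
        ++m[x];                                // a new key starts at 0, then 1
        if (m.size() > n) m.erase(m.begin());  // evict the smallest value
    }
    for (const auto& kv : m) {
        out.values.push_back(kv.first);
        out.lengths.push_back(kv.second);
    }
    return out;
}
```

For `numv` and n = 3 this returns values 4.2, 4.4, 4.5 with lengths 2, 2, 1; if v has fewer than n distinct values, all of them are returned.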

Solutions so far are:

 library(microbenchmark)
 library(data.table)
 library(DescTools)

 set.seed(1789)
 x <- sample(round(rnorm(1000), 3), 1e5, replace = TRUE)
 n <- 5

 microbenchmark(
   BaseR = tail(table(x), n),
   data.table = data.table(x)[, .N, keyby = x][(.N - n + 1):.N],
   DescTools = Large(x, n, unique=TRUE),
   Coatless = ...
 )

Unit: milliseconds
       expr       min         lq       mean     median        uq       max neval
      BaseR 188.09662 190.830975 193.189422 192.306297 194.02815 253.72304   100
 data.table  11.23986  11.553478  12.294456  11.768114  12.25475  15.68544   100
  DescTools   4.01374   4.174854   5.796414   4.410935   6.70704  64.79134   100

DescTools is still the fastest, but since it is pure R, I'm sure an Rcpp implementation could improve on it significantly!

Asked by Andri Signorell, Dec 25 '22 06:12


2 Answers

I'd like to throw my hat in the ring with another Rcpp-based solution, which is ~7x faster than the DescTools approach and ~13x faster than the data.table approach, using the 1e5-length x and n = 5 sample data above. The implementation is a bit lengthy, so I'll lead with the benchmark:

fn.dt <- function(v, n) {
    data.table(v = v)[
      ,.N, keyby = v
      ][(.N - n + 1):.N]
}

microbenchmark(
    "DescTools" = Large(x, n, unique=TRUE),
    "top_n" = top_n(x, 5),
    "data.table" = fn.dt(x, n),
    times = 500L
)
# Unit: microseconds
#        expr      min       lq      mean   median       uq       max neval
#   DescTools 3330.527 3790.035 4832.7819 4070.573 5323.155 54921.615   500
#       top_n  566.207  587.590  633.3096  593.577  640.832  3568.299   500
#  data.table 6920.636 7380.786 8072.2733 7764.601 8585.472 14443.401   500

Update

If your compiler supports C++11, you can take advantage of std::priority_queue::emplace for a (surprisingly) noticeable performance boost (compared to the C++98 version below). I won't post this version as it is mostly identical, save for a few calls to std::move and emplace, but here's a link to it.

Testing this against the previous three functions, and using data.table 1.9.7 (which is a bit faster than 1.9.6) yields

print(res2, order = "median", signif = 3)
# Unit: relative
#              expr  min    lq      mean median    uq   max neval  cld
#            top_n2  1.0  1.00  1.000000   1.00  1.00  1.00  1000    a   
#             top_n  1.6  1.58  1.666523   1.58  1.75  2.75  1000    b  
#         DescTools 10.4 10.10  8.512887   9.68  7.19 12.30  1000    c 
#  data.table-1.9.7 16.9 16.80 14.164139  15.50 10.50 43.70  1000    d 

where top_n2 is the C++11 version.
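The C++11 source isn't reproduced here. As a hypothetical illustration of the two features mentioned, this sketch contrasts `push` of a temporary pair with `emplace`, which constructs the (value, count) pair in place, and replaces a queue by move assignment. The type and function names are assumptions, and the sketch does not establish where top_n2's speed-up actually comes from.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical types for illustration: (value, count) in a min-heap.
typedef std::pair<double, unsigned int> Count;
typedef std::priority_queue<Count, std::vector<Count>, std::greater<Count> > CountQueue;

// C++98 style: build a temporary pair, then copy it into the queue.
void add_push(CountQueue& q, double x) {
    q.push(std::make_pair(x, 1u));
}

// C++11 style: forward the arguments so the pair is constructed in place.
void add_emplace(CountQueue& q, double x) {
    q.emplace(x, 1u);
}

// C++11 style: hand over a rebuilt queue by move assignment, which takes
// ownership of rebuilt's underlying vector instead of copying it.
void replace_queue(CountQueue& q, CountQueue&& rebuilt) {
    q = std::move(rebuilt);
}
```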


The top_n function is implemented as follows:

#include <Rcpp.h>
#include <utility>
#include <queue>

class histogram {
private:
    struct paired {
        typedef std::pair<double, unsigned int> pair_t;

        pair_t pair;
        unsigned int is_set;

        paired() 
            : pair(pair_t()),
              is_set(0)
        {}

        paired(double x)
            : pair(std::make_pair(x, 1)),
              is_set(1)
        {}

        bool operator==(const paired& other) const {
            return pair.first == other.pair.first;
        }

        bool operator==(double other) const {
            return is_set && (pair.first == other);
        }

        bool operator>(double other) const {
            return is_set && (pair.first > other);
        }

        bool operator<(double other) const {
            return is_set && (pair.first < other);
        }

        paired& operator++() {
            ++pair.second;
            return *this;
        }

        paired operator++(int) {
            paired tmp(*this);
            ++(*this);
            return tmp;
        }
    };

    struct greater {
        bool operator()(const paired& lhs, const paired& rhs) const {
            if (!lhs.is_set) return false;
            if (!rhs.is_set) return true;
            return lhs.pair.first > rhs.pair.first;
        }
    };  

    typedef std::priority_queue<
        paired,
        std::vector<paired>,
        greater
    > queue_t;

    unsigned int sz;
    queue_t queue;

    void insert(double x) {
        if (queue.empty()) {
            queue.push(paired(x));
            return;
        }

        if (queue.top() > x && queue.size() >= sz) return;

        queue_t qtmp;
        bool matched = false;

        while (queue.size()) {
            paired elem = queue.top();
            if (elem == x) {
                qtmp.push(++elem);
                matched = true;
            } else {
                qtmp.push(elem);
            }
            queue.pop();
        }

        if (!matched) {
            if (qtmp.size() >= sz) qtmp.pop();
            qtmp.push(paired(x));
        }

        std::swap(queue, qtmp);
    }

public:
    histogram(unsigned int sz_) 
        : sz(sz_), 
          queue(queue_t())
    {}

    template <typename InputIt>
    void insert(InputIt first, InputIt last) {
        for ( ; first != last; ++first) {
            insert(*first);
        }
    }

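    Rcpp::List get() const {
        // Size by the number of values actually kept, which can be < sz
        const R_xlen_t k = static_cast<R_xlen_t>(queue.size());
        Rcpp::NumericVector values(k);
        Rcpp::IntegerVector freq(k);
        R_xlen_t i = 0;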

        queue_t tmp(queue);
        while (tmp.size()) {
            values[i] = tmp.top().pair.first;
            freq[i] = tmp.top().pair.second;
            ++i;
            tmp.pop();
        }

        return Rcpp::List::create(
            Rcpp::Named("value") = values,
            Rcpp::Named("frequency") = freq);
    }
};


// [[Rcpp::export]]
Rcpp::List top_n(Rcpp::NumericVector x, int n = 5) {
    if (n < 1) Rcpp::stop("n must be at least 1");
    histogram h(n);
    h.insert(x.begin(), x.end());
    return h.get();
}

There's a lot going on in the histogram class above, but just to touch on some of the key points:

  • The paired type is essentially a wrapper class around an std::pair<double, unsigned int>, which associates a value with a count, providing some convenience features such as operator++() / operator++(int) for direct pre-/post-increment of the count, and modified comparison operators.
  • The histogram class wraps a sort of "managed" priority queue, in the sense that the size of std::priority_queue is capped at a particular value sz.
  • Instead of the default std::less ordering (which makes std::priority_queue a max-heap), I'm using a greater-than comparator, so top() returns the smallest value currently kept. Each candidate can then be checked against std::priority_queue::top() to quickly determine whether it should (a) be discarded, (b) replace the current minimum value in the queue, or (c) update the count of a value already in the queue. This is only possible because the size of the queue is restricted to at most sz.
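To make the capped-queue logic easier to follow outside of Rcpp, here is a plain C++ restatement of `histogram::insert` and `get`. It is a sketch: `capped_insert`, `top_n_plain` and `ValueGreater` are hypothetical names, and the `is_set` flag is dropped because every entry is constructed with a value. Steps (a), (b) and (c) from the last bullet are marked in the comments.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

// (value, count) entry; the ordering below looks at the value only.
typedef std::pair<double, unsigned int> Entry;

// Greater-than on the value: std::priority_queue then behaves as a
// min-heap, so top() is the smallest value currently kept.
struct ValueGreater {
    bool operator()(const Entry& lhs, const Entry& rhs) const {
        return lhs.first > rhs.first;
    }
};

typedef std::priority_queue<Entry, std::vector<Entry>, ValueGreater> MinHeap;

// One insertion step of the capped queue (at most sz entries).
void capped_insert(MinHeap& heap, double x, std::size_t sz) {
    if (sz == 0) return;
    // (a) Full and x is below the current minimum: discard.
    if (heap.size() >= sz && heap.top().first > x) return;

    // Otherwise rebuild, incrementing the count if x is already present.
    MinHeap rebuilt;
    bool matched = false;
    while (!heap.empty()) {
        Entry e = heap.top();
        heap.pop();
        if (e.first == x) {  // (c) existing value: bump its count
            ++e.second;
            matched = true;
        }
        rebuilt.push(e);
    }
    if (!matched) {
        if (rebuilt.size() >= sz) rebuilt.pop();  // (b) evict the minimum
        rebuilt.push(Entry(x, 1u));
    }
    heap.swap(rebuilt);
}

// Drain the heap: smallest value first, so the output is ascending.
std::vector<Entry> top_n_plain(const std::vector<double>& v, std::size_t n) {
    MinHeap heap;
    for (std::size_t i = 0; i < v.size(); ++i) capped_insert(heap, v[i], n);
    std::vector<Entry> out;
    while (!heap.empty()) {
        out.push_back(heap.top());
        heap.pop();
    }
    return out;
}
```

Every insertion that is not discarded in step (a) rebuilds the whole queue at O(sz log sz), which stays cheap only because sz is small.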
Answered by nrussell, Jan 04 '23 23:01


I'd wager data.table is competitive:

library(data.table)

data <- data.table(v)

data[ , .N, keyby = v][(.N - n + 1):.N]

where n is the number of largest unique values you want.

Answered by MichaelChirico, Jan 04 '23 22:01