Java Apache Spark: Long transformation chains result in quadratic time

I have a Java program using Apache Spark. The most interesting part of the program looks like this:

long seed = System.nanoTime();

JavaRDD<AnnotatedDocument> annotated = documents
    .mapPartitionsWithIndex(new InitialAnnotater(seed), true);
annotated.cache();

for (int iter = 0; iter < 2000; iter++) {
    GlobalCounts counts = annotated
        .mapPartitions(new GlobalCounter())
        .reduce((a, b) -> a.sum(b)); // update overall counts (*)

    seed = System.nanoTime();

    // copy overall counts which CountChanger uses to compute a stochastic thing (**)
    annotated = annotated
        .mapPartitionsWithIndex(new CountChanger(counts, seed),  true); 
    annotated.cache();

    // uncommenting these lines gives the constant per-iteration time I want
    //List<AnnotatedDocument> ll = annotated.collect();
    //annotated = sc.parallelize(ll, 8); 
}

So in effect, line (**) results in an RDD of the form

documents
    .mapPartitionsWithIndex(initial)
    .mapPartitionsWithIndex(nextIter)
    .mapPartitionsWithIndex(nextIter)
    .mapPartitionsWithIndex(nextIter)
    ... 2000 more

a very long chain of maps indeed. In addition, line (*) forces computation (it is not lazy) at each iteration, since the counts need to be updated.

The problem I have is that I get a time complexity that increases linearly with each iteration, and so quadratic overall:

[Plot: per-iteration runtime grows roughly linearly with the iteration number]

I think this is because Spark tries to "remember" every RDD in the chain, and the lineage it keeps for fault tolerance keeps growing. However, I really have no idea.

What I'd really like to do is at each iteration tell Spark to "collapse" the RDD so that only the last one is kept in memory and worked on. This should result in constant time per iteration, I think. Is this possible? Are there any other solutions?

Thanks!

asked Mar 21 '16 by bombax

2 Answers

Try using rdd.checkpoint. This will save the RDD to HDFS and clear the lineage.

Each time you transform an RDD you grow the lineage, and Spark has to track what is available and what has to be re-computed. Processing the DAG is expensive, and large DAGs tend to kill performance quite quickly. By "checkpointing" you instruct Spark to compute and save the resulting RDD and discard the information about how it was created. This makes it similar to simply saving an RDD and reading it back, which minimizes DAG operations.
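
A minimal sketch of how this could be wired into the loop from the question (the checkpoint directory path is just a placeholder; sc and annotated are the variables from the question):

// One-time setup before the loop: where checkpointed RDD data is written.
sc.setCheckpointDir("hdfs:///tmp/spark-checkpoints");

// Inside the loop, right after the new RDD is produced:
annotated.cache();       // keep the data in memory so checkpointing does not recompute it
annotated.checkpoint();  // mark for checkpointing; the next action (the reduce at the
                         // top of the following iteration) materializes it and truncates the lineage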

On a side note, since you hit this issue, it is good to know that union also impacts RDD performance by adding steps, and can also throw a StackOverflowError due to the way lineage information is tracked. See this post

This link has more details with nice diagrams and the subject is also mentioned in this SO post.

answered Sep 19 '22 by Ioannis Deligiannis


That's a really interesting question and there are a few things to consider.

Fundamentally this is an iterative algorithm; if you look at some of the different iterative machine learning algorithms in Spark, you can see some approaches to working with this kind of problem.

The first thing is that most of them don't cache on every iteration; rather, they have a configurable caching interval. I'd probably start by caching every 10 iterations and seeing how that goes.

The other issue is the lineage graph: each mapPartitions you do grows the graph a little more. At some point, keeping track of that data becomes more and more expensive. checkpoint allows you to have Spark write the current RDD to persistent storage and discard the lineage information. You could try doing this at some interval, like every 20 iterations, and see how it goes.

The 10 and 20 numbers are just basic starting points; they depend on how slow it is to compute the data for each individual iteration, and you can play with them to find the right tuning for your job.
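
A rough sketch of what that interval-based approach might look like, applied to the loop from the question (the 10/20 intervals and the checkpoint directory path are illustrative placeholders, not values prescribed by either answer):

sc.setCheckpointDir("hdfs:///tmp/spark-checkpoints");

for (int iter = 0; iter < 2000; iter++) {
    GlobalCounts counts = annotated
        .mapPartitions(new GlobalCounter())
        .reduce((a, b) -> a.sum(b));

    seed = System.nanoTime();
    annotated = annotated
        .mapPartitionsWithIndex(new CountChanger(counts, seed), true);

    // Cache only every 10th iteration instead of every iteration.
    if (iter % 10 == 0) {
        annotated.cache();
    }

    // Every 20th iteration, write the RDD to stable storage and
    // truncate the lineage so the DAG stops growing.
    if (iter % 20 == 0) {
        annotated.checkpoint();
    }
}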

answered Sep 21 '22 by Holden