I have a Java program using Apache Spark. The most interesting part of the program looks like this:
long seed = System.nanoTime();
JavaRDD<AnnotatedDocument> annotated = documents
    .mapPartitionsWithIndex(new InitialAnnotater(seed), true);
annotated.cache();

for (int iter = 0; iter < 2000; iter++) {
    GlobalCounts counts = annotated
        .mapPartitions(new GlobalCounter())
        .reduce((a, b) -> a.sum(b)); // update overall counts (*)

    seed = System.nanoTime();
    // copy overall counts which CountChanger uses to compute a stochastic thing (**)
    annotated = annotated
        .mapPartitionsWithIndex(new CountChanger(counts, seed), true);
    annotated.cache();

    // adding these lines gives the constant time per iteration I want
    // List<AnnotatedDocument> ll = annotated.collect();
    // annotated = sc.parallelize(ll, 8);
}
So in effect, line (**) results in an RDD of the form
documents
    .mapPartitionsWithIndex(initial)
    .mapPartitionsWithIndex(nextIter)
    .mapPartitionsWithIndex(nextIter)
    .mapPartitionsWithIndex(nextIter)
    ... 2000 more
a very long chain of maps indeed. In addition, line (*) forces computation (non-lazy) at each iteration as counts need to be updated.
The problem I have is that the time per iteration grows linearly with the iteration number, so the total running time is quadratic in the number of iterations.
I think this is because Spark tries to "remember" every RDD in the chain, and the fault-tolerance bookkeeping needed to recompute them keeps growing. However, I really have no idea.
What I'd really like to do is tell Spark, at each iteration, to "collapse" the RDD so that only the latest one is kept in memory and worked on. That should give constant time per iteration, I think. Is this possible? Are there any other solutions?
Thanks!
Try using rdd.checkpoint(). This will save the RDD to HDFS and clear its lineage.
Each time you transform an RDD you grow the lineage, and Spark has to track what is available and what has to be re-computed. Processing the DAG is expensive, and large DAGs tend to kill performance quite quickly. By "checkpointing" you instruct Spark to compute and save the resulting RDD and discard the information about how it was created. That makes it similar to simply saving an RDD and reading it back, which minimizes DAG processing.
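For illustration, here is a minimal sketch of the mechanics, assuming sc is the JavaSparkContext from the question; the checkpoint directory path is a made-up example and count() is just a convenient action used to force the checkpoint to actually be written:

sc.setCheckpointDir("hdfs:///tmp/spark-checkpoints"); // set once, before any checkpoint() call
annotated.cache();       // persist first so the data is not computed a second time for the checkpoint
annotated.checkpoint();  // marks the RDD; its lineage is dropped once the checkpoint is written
annotated.count();       // checkpointing happens during the next action, so run one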
On a side note, since you have hit this issue, it is good to know that union also impacts RDD performance by adding steps to the lineage, and it can even throw a StackOverflowError because of the way lineage information is tracked. See this post.
This link has more details with nice diagrams, and the subject is also mentioned in this SO post.
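To make that side note concrete, here is a hypothetical illustration (the parts list and its element type are invented for the example, not taken from the question):

JavaRDD<AnnotatedDocument> merged = sc.emptyRDD();
for (JavaRDD<AnnotatedDocument> part : parts) { // parts: some collection of small RDDs
    merged = merged.union(part);                // every union adds another step to the lineage
}
// after many unions the lineage is just as deep as the long map chain above,
// and periodic checkpointing helps here for the same reason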
That's a really interesting question and there are a few things to consider.
Fundamentally this is an iterative algorithm; if you look at some of the different iterative machine learning algorithms in Spark, you can see some approaches to working with this kind of problem.
The first thing is that most of them don't cache on each iteration; rather, they have a configurable caching interval. I'd probably start by caching every 10 iterations and seeing how that goes.
The other issue is the lineage graph: each mapPartitions you do grows the graph a little more, and at some point keeping track of that data starts to become more and more expensive. checkpoint lets you have Spark write the current RDD to persistent storage and discard the lineage information. You could try doing this at some interval, say every 20 iterations, and see how that goes.
The 10 and 20 are just basic starting points; they depend on how slow it is to compute the data for each individual iteration, and you can play with them to find the right tuning for your job.
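Putting the two suggestions together, a rough sketch of what the loop could look like (the interval constants and the count() used to force the checkpoint are illustrative choices, and sc.setCheckpointDir(...) is assumed to have been called beforehand):

int cacheInterval = 10;       // starting point, tune for your job
int checkpointInterval = 20;  // starting point, tune for your job

for (int iter = 0; iter < 2000; iter++) {
    GlobalCounts counts = annotated
        .mapPartitions(new GlobalCounter())
        .reduce((a, b) -> a.sum(b));

    seed = System.nanoTime();
    annotated = annotated
        .mapPartitionsWithIndex(new CountChanger(counts, seed), true);

    if (iter % cacheInterval == 0) {
        annotated.cache();          // cache only every few iterations
    }
    if (iter % checkpointInterval == 0) {
        annotated.checkpoint();     // periodically cut the lineage
        annotated.count();          // an action is needed so the checkpoint is actually written
    }
}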