Slow Performance with Apache Spark Gradient Boosted Tree training runs

I'm experimenting with the Gradient Boosted Trees learning algorithm from the ML library of Spark 1.4. I'm solving a binary classification problem where my input is ~50,000 samples and ~500,000 features. My goal is to output the definition of the resulting GBT ensemble in a human-readable format. My experience so far is that, for my problem size, adding more resources to the cluster does not seem to affect the length of the run. A 10-iteration training run seems to take roughly 13 hours. This isn't acceptable since I'm looking to do 100-300 iteration runs, and the execution time seems to explode with the number of iterations.

My Spark application

This isn't the exact code, but it can be reduced to:

SparkConf sc = new SparkConf().setAppName("GBT Trainer")
            // unlimited max result size for intermediate Map-Reduce ops.
            // Having no limit is probably bad, but I've not had time to find
            // a tighter upper bound and the default value wasn't sufficient.
            .set("spark.driver.maxResultSize", "0");
JavaSparkContext jsc = new JavaSparkContext(sc);

// The input file is encoded in plain-text LIBSVM format ~59GB in size
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), "s3://somebucket/somekey/plaintext_libsvm_file").toJavaRDD();

BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification");
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();

GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);

// Somewhat-convoluted code below reads in Parquet-formatted output
// of the GBT model and writes it back out as json.
// There might be cleaner ways of achieving the same, but since output
// size is only a few KB I feel little guilt leaving it as is.

// serialize and output the GBT classifier model the only way that the library allows
String outputPath = "s3://somebucket/somekeyprefex";
model.save(jsc.sc(), outputPath + "/parquet");
// read in the parquet-formatted classifier output as a generic DataFrame object
SQLContext sqlContext = new SQLContext(jsc);
DataFrame outputDataFrame = sqlContext.read().parquet(outputPath + "/parquet");
// output DataFrame-formatted classifier model as json
outputDataFrame.write().format("json").save(outputPath + "/json");


What is the performance bottleneck with my Spark application (or with GBT learning algorithm itself) on input of that size and how can I achieve greater execution parallelism?

I'm still a novice Spark dev, and I'd appreciate any tips on cluster configuration and execution profiling.

More details on the cluster setup

I'm running this app on an AWS EMR cluster (emr-4.0.0, YARN cluster mode) of r3.8xlarge instances (32 cores, 244GB RAM each). I'm using such large instances in order to maximize flexibility of resource allocation. So far I've tried using 1-3 r3.8xlarge instances with a variety of resource allocation schemes between the driver and workers. For example, for a cluster of 1 r3.8xlarge instance I submit the app as follows:

aws emr add-steps --cluster-id $1 --steps Name=$2,\

For a cluster of 3 r3.8xlarge instances I tweak resource allocation:


I don't have a clear idea of how much memory is useful to give to every executor, but I feel that I'm being generous in either case. Looking through the Spark UI, I'm not seeing any tasks with an input size of more than a few GB. I'm erring on the side of caution when giving the driver process so much memory, in order to ensure that it isn't memory-starved during any intermediate result-aggregation operations.

I'm trying to keep the number of cores per executor down to 5 as per suggestions in Cloudera's How To Tune Your Spark Jobs series (according to them, more than 5 cores tends to introduce an HDFS IO bottleneck). I'm also making sure that there is enough spare RAM and CPU left over for the host OS and Hadoop services.

My findings thus far

My only clue is Spark UI showing very long Scheduling Delay for a number of tasks at the tail-end of execution. I also get the feeling that the stages/tasks timeline shown by Spark UI does not account for all of the time that the job takes to finish. I suspect that the driver application is stuck performing some kind of a lengthy operation either at the end of every training iteration, or at the end of the entire training run.

I've already done a fair bit of research on tuning Spark applications. Most articles will give great suggestions on using RDD operations which reduce intermediate input size or avoid shuffling of data between stages. In my case I'm basically using an "out-of-the-box" algorithm, which is written by ML experts and should already be well tuned in this regard. My own code that outputs GBT model to S3 should take a trivial amount of time to run.

I haven't used MLlib's GBT implementation, but I have used both LightGBM and XGBoost successfully. I'd highly suggest taking a look at these other libraries.

In general, GBM implementations need to train models iteratively, since each new tree is fit against the loss of the entire ensemble built so far. This makes GBM training inherently sequential and not easily parallelizable (unlike random forests, which are trivially parallelizable). I'd expect it to perform better with fewer tasks, but that might not be your whole issue. Since you have so many features (500K), you're going to have very high overhead when calculating the histograms and split points during training. You should reduce the number of features you have, especially since there are far more features than samples, which makes the model prone to overfitting.
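
To put a rough number on that histogram overhead, here is a back-of-the-envelope sketch. It assumes 32 bins per feature (which I believe is MLlib's default maxBins) and three doubles of statistics per bin (count, sum, sum of squares). Both figures are assumptions on my part that I haven't verified against your Spark version, and the class and method names are made up for illustration.

```java
public class HistogramCost {
    // Bytes needed for one tree node's split-statistics array when every
    // feature gets `bins` buckets and each bucket holds `statsPerBin` doubles.
    public static long bytesPerNode(long numFeatures, long bins, long statsPerBin) {
        return numFeatures * bins * statsPerBin * 8L; // 8 bytes per double
    }

    public static void main(String[] args) {
        long numFeatures = 500_000L; // from the question
        long bins = 32L;             // assumed: default number of bins per feature
        long statsPerBin = 3L;       // assumed: count, sum, sum of squares

        long bytes = bytesPerNode(numFeatures, bins, statsPerBin);
        System.out.println(bytes / 1_000_000L + " MB per node"); // prints "384 MB per node"
    }
}
```

Under those assumptions every node being split needs a few hundred MB of statistics to be aggregated across the cluster, which is why cutting the feature count (or the number of bins) should help far more than adding machines.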

As for tuning your cluster: you want to minimize data movement, so use fewer executors with more memory: 1 executor per EC2 instance, with the number of cores set to whatever the instance provides.
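
As a minimal sketch of what "one fat executor per instance" could look like for the 3-node case, the snippet below derives spark-submit flags from the instance specs in the question. The 24 GB held back for the OS and Hadoop services and the 10% YARN memory overhead are hypothetical values I picked for illustration, and the helper names are my own. Adjust them to whatever your YARN configuration actually allows.

```java
public class ExecutorSizing {
    // Heap (in GB) to request for a single executor that owns a whole node:
    // subtract what is held back for the OS and Hadoop services, then leave
    // room for the off-heap overhead that YARN adds on top of the heap.
    public static int executorMemoryGb(int nodeRamGb, int reservedGb, int overheadPercent) {
        return (nodeRamGb - reservedGb) * 100 / (100 + overheadPercent);
    }

    // Render the sizing as spark-submit flags for a YARN deployment.
    public static String submitFlags(int executors, int coresPerExecutor, int memoryGb) {
        return "--num-executors " + executors
             + " --executor-cores " + coresPerExecutor
             + " --executor-memory " + memoryGb + "g";
    }

    public static void main(String[] args) {
        int nodes = 3;            // r3.8xlarge instances, from the question
        int coresPerNode = 32;    // from the question
        int ramPerNodeGb = 244;   // from the question
        int reservedGb = 24;      // hypothetical headroom for OS + Hadoop services
        int overheadPercent = 10; // assumed YARN memory overhead on top of the heap

        int memGb = executorMemoryGb(ramPerNodeGb, reservedGb, overheadPercent);
        // prints "--num-executors 3 --executor-cores 32 --executor-memory 200g"
        System.out.println(submitFlags(nodes, coresPerNode, memGb));
    }
}
```

Keep in mind that in YARN cluster mode the driver also runs in a container on one of the nodes, so in practice you would have to shave some cores and memory off at least one executor to make room for it.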

Your data is small enough to fit into ~2 EC2 instances of that size. Assuming you are using doubles (8 bytes) and a dense layout, it comes to 8 * 500000 * 50000 bytes = 200 GB. Try loading it all into RAM by using .cache() on your dataframe. If you then perform an operation over all the rows (like a sum), you force it to load, and you can measure how long the IO takes. Once it's in RAM and cached, any other operations over it will be faster.
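
To measure that load time, something like the small timing helper below would do. It is plain Java with a stand-in workload so that it runs on its own; LoadTimer and Timed are hypothetical names. In your application I'd assume the timed action would be something like `() -> data.cache().count()` on the RDD from your code, but I haven't run that against your setup.

```java
import java.util.function.Supplier;

public class LoadTimer {
    // Holds an action's result together with the wall-clock time it took.
    public static final class Timed<T> {
        public final T value;
        public final long millis;

        Timed(T value, long millis) {
            this.value = value;
            this.millis = millis;
        }
    }

    // Runs the action once and reports how long it took in milliseconds.
    public static <T> Timed<T> time(Supplier<T> action) {
        long start = System.nanoTime();
        T value = action.get();
        long elapsedMillis = (System.nanoTime() - start) / 1_000_000L;
        return new Timed<T>(value, elapsedMillis);
    }

    public static void main(String[] args) {
        // Stand-in for the real work: in the Spark job this would be an
        // action that touches every row, so the whole input gets read
        // and cached before training starts.
        Timed<Long> result = time(() -> {
            long sum = 0L;
            for (long i = 0; i < 1_000_000L; i++) {
                sum += i;
            }
            return sum;
        });
        System.out.println("touched all rows in " + result.millis + " ms, sum=" + result.value);
    }
}
```

Timing the first full pass separately from the training call tells you how much of the 13 hours is just reading the 59GB file from S3 versus actual tree building.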

