
Spark iteration time increasing exponentially when using join

I'm quite new to Spark and I'm trying to implement an iterative clustering algorithm (expectation-maximization) with centroids represented by a Markov model. So I need to do iterations and joins.

One problem I'm experiencing is that the time of each iteration grows exponentially.
After some experimenting I found that when iterating you need to persist the RDD that is going to be reused in the next iteration; otherwise on every iteration Spark will create an execution plan that recalculates the RDD from the start, thus increasing calculation time.

import datetime

init = sc.parallelize(xrange(10000000), 3)
init.cache()

for i in range(6):
    print i
    start = datetime.datetime.now()

    init2 = init.map(lambda n: (n, n*3))        
    init = init2.map(lambda n: n[0])
#     init.cache()

    print init.count()    
    print str(datetime.datetime.now() - start)

Results in:

0
10000000
0:00:04.283652
1
10000000
0:00:05.998830
2
10000000
0:00:08.771984
3
10000000
0:00:11.399581
4
10000000
0:00:14.206069
5
10000000
0:00:16.856993

So adding cache() helps, and iteration time becomes constant.

init = sc.parallelize(xrange(10000000), 3)
init.cache()

for i in range(6):
    print i
    start = datetime.datetime.now()

    init2 = init.map(lambda n: (n, n*3))        
    init = init2.map(lambda n: n[0])
    init.cache()

    print init.count()    
    print str(datetime.datetime.now() - start)

Results in:

0
10000000
0:00:04.966835
1
10000000
0:00:04.609885
2
10000000
0:00:04.324358
3
10000000
0:00:04.248709
4
10000000
0:00:04.218724
5
10000000
0:00:04.223368
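
A rough way to picture the difference between the two runs above is a toy model in plain Python (not Spark itself). The class `LazyDataset` and the function `map_stages_per_iteration` are hypothetical names made up for this sketch; it only assumes that transformations are lazy and that an uncached result is recomputed from its parents every time an action runs.

```python
class LazyDataset(object):
    """Toy stand-in for an RDD: map() is lazy, collect() recomputes unless cached."""

    evaluations = 0  # number of map stages executed since the last reset

    def __init__(self, source=None, parent=None, func=None):
        self.source = source
        self.parent = parent
        self.func = func
        self.use_cache = False
        self.cached = None

    def map(self, func):
        # Only records the transformation; nothing is computed here.
        return LazyDataset(parent=self, func=func)

    def cache(self):
        self.use_cache = True
        return self

    def collect(self):
        if self.cached is not None:
            return self.cached
        if self.parent is None:
            data = list(self.source)
        else:
            LazyDataset.evaluations += 1
            data = [self.func(x) for x in self.parent.collect()]
        if self.use_cache:
            self.cached = data
        return data

    def count(self):
        return len(self.collect())


def map_stages_per_iteration(use_cache, iterations=4):
    """Return how many map stages each iteration's count() has to run."""
    data = LazyDataset(source=range(100)).cache()
    costs = []
    for _ in range(iterations):
        pairs = data.map(lambda n: (n, n * 3))
        data = pairs.map(lambda pair: pair[0])
        if use_cache:
            data.cache()
        LazyDataset.evaluations = 0
        data.count()
        costs.append(LazyDataset.evaluations)
    return costs


print(map_stages_per_iteration(use_cache=False))  # [2, 4, 6, 8]
print(map_stages_per_iteration(use_cache=True))   # [2, 2, 2, 2]
```

In this model the uncached lineage gets two stages longer per iteration, which gives linear rather than exponential growth, consistent with the timings of the first run.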

But when doing a join inside the iteration the problem comes back. Here is some simple code demonstrating the problem. Even caching the result of each RDD transformation doesn't solve it:

init = sc.parallelize(xrange(10000), 3)
init.cache()

for i in range(6):
    print i
    start = datetime.datetime.now()

    init2 = init.map(lambda n: (n, n*3))
    init2.cache()

    init3 = init.map(lambda n: (n, n*2))
    init3.cache()

    init4 = init2.join(init3)
    init4.count()
    init4.cache()

    init = init4.map(lambda n: n[0])
    init.cache()

    print init.count()    
    print str(datetime.datetime.now() - start)

And here is the output. As you can see, iteration time grows exponentially :(

0
10000
0:00:00.674115
1
10000
0:00:00.833377
2
10000
0:00:01.525314
3
10000
0:00:04.194715
4
10000
0:00:08.139040
5
10000
0:00:17.852815

I will really appreciate any help :)

sashaostr asked Jul 27 '15 17:07



2 Answers

Summary:

Generally speaking, iterative algorithms, especially ones with a self-join or self-union, require control over:

  • Length of the lineage (see for example "Stackoverflow due to long RDD Lineage" and "unionAll resulting in StackOverflow").
  • Number of partitions.

The problem described here is a result of the lack of the latter. In each iteration the number of partitions increases with the self-join, leading to an exponential pattern. To address that you have to either control the number of partitions in each iteration (see below) or use global tools like spark.default.parallelism (see the answer provided by Travis). The first approach generally provides much more control and doesn't affect other parts of the code.

Original answer:

As far as I can tell there are two interleaved problems here: a growing number of partitions and shuffling overhead during joins. Both can be easily handled, so let's go step by step.

First let's create a helper to collect the statistics:

import datetime

def get_stats(i, init, init2, init3, init4,
       start, end, desc, cache, part, hashp):
    return {
        "i": i,
        "init": init.getNumPartitions(),
        # Note: the keys "init1" and "init2" hold the partition counts of
        # the RDDs init2 and init3; the tables below use these key names.
        "init1": init2.getNumPartitions(),
        "init2": init3.getNumPartitions(),
        "init4": init4.getNumPartitions(),
        "time": str(end - start),
        "timen": (end - start).total_seconds(),
        "desc": desc,
        "cache": cache,
        "part": part,
        "hashp": hashp
    }

another helper to handle caching/partitioning:

def procRDD(rdd, cache=True, part=False, hashp=False, npart=16):
    rdd = rdd if not part else rdd.repartition(npart)
    rdd = rdd if not hashp else rdd.partitionBy(npart)
    return rdd if not cache else rdd.cache()

extract pipeline logic:

def run(init, description, cache=True, part=False, hashp=False, 
    npart=16, n=6):
    times = []

    for i in range(n):
        start = datetime.datetime.now()

        init2 = procRDD(
                init.map(lambda n: (n, n*3)),
                cache, part, hashp, npart)
        init3 = procRDD(
                init.map(lambda n: (n, n*2)),
                cache, part, hashp, npart)


        # If part set to True limit number of the output partitions
        init4 = init2.join(init3, npart) if part else init2.join(init3) 
        init = init4.map(lambda n: n[0])

        if cache:
            init4.cache()
            init.cache()

        init.count() # Force computations to get time
        end = datetime.datetime.now() 

        times.append(get_stats(
            i, init, init2, init3, init4,
            start, end, description,
            cache, part, hashp
        ))

    return times

and create initial data:

ncores = 8
init = sc.parallelize(xrange(10000), ncores * 2).cache()

By itself, if the numPartitions argument is not provided, the join operation adjusts the number of partitions in the output based on the number of partitions of the input RDDs. This means the number of partitions grows with each iteration. If the number of partitions is too large, things get ugly. You can deal with this by providing the numPartitions argument to join or by repartitioning the RDDs in each iteration.

timesCachePart = sqlContext.createDataFrame(
        run(init, "cache + partition", True, True, False, ncores * 2))
timesCachePart.select("i", "init1", "init2", "init4", "time", "desc").show()

+-+-----+-----+-----+--------------+-----------------+
|i|init1|init2|init4|          time|             desc|
+-+-----+-----+-----+--------------+-----------------+
|0|   16|   16|   16|0:00:01.145625|cache + partition|
|1|   16|   16|   16|0:00:01.090468|cache + partition|
|2|   16|   16|   16|0:00:01.059316|cache + partition|
|3|   16|   16|   16|0:00:01.029544|cache + partition|
|4|   16|   16|   16|0:00:01.033493|cache + partition|
|5|   16|   16|   16|0:00:01.007598|cache + partition|
+-+-----+-----+-----+--------------+-----------------+

As you can see, when we repartition, execution time is more or less constant. The second problem is that the above data is partitioned randomly. To ensure join performance we would like to have the same keys on a single partition. To achieve that we can use a hash partitioner:

timesCacheHashPart = sqlContext.createDataFrame(
    run(init, "cache + hashpart", True, True, True, ncores * 2))
timesCacheHashPart.select("i", "init1", "init2", "init4", "time", "desc").show()

+-+-----+-----+-----+--------------+----------------+
|i|init1|init2|init4|          time|            desc|
+-+-----+-----+-----+--------------+----------------+
|0|   16|   16|   16|0:00:00.946379|cache + hashpart|
|1|   16|   16|   16|0:00:00.966519|cache + hashpart|
|2|   16|   16|   16|0:00:00.945501|cache + hashpart|
|3|   16|   16|   16|0:00:00.986777|cache + hashpart|
|4|   16|   16|   16|0:00:00.960989|cache + hashpart|
|5|   16|   16|   16|0:00:01.026648|cache + hashpart|
+-+-----+-----+-----+--------------+----------------+

Execution time is constant as before, and there is a small improvement over the basic partitioning.
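
To illustrate why having the same keys on a single partition helps, here is a minimal plain-Python sketch (no Spark involved). The helpers `hash_partition` and `join_copartitioned` are hypothetical names invented for this example, and `hash(key) % num_partitions` is only a stand-in for whatever hash function the real partitioner uses. The point is that when both sides are partitioned the same way, matching keys always sit in buckets with the same index, so the join can proceed bucket by bucket.

```python
def hash_partition(pairs, num_partitions):
    """Place each (key, value) pair in the bucket chosen by the key's hash."""
    buckets = [[] for _ in range(num_partitions)]
    for key, value in pairs:
        buckets[hash(key) % num_partitions].append((key, value))
    return buckets


def join_copartitioned(left_buckets, right_buckets):
    """Join two datasets that were partitioned with the same function."""
    joined = []
    for left, right in zip(left_buckets, right_buckets):
        # Only the bucket with the same index has to be inspected.
        lookup = {}
        for key, value in right:
            lookup.setdefault(key, []).append(value)
        for key, value in left:
            for other in lookup.get(key, []):
                joined.append((key, (value, other)))
    return joined


left = hash_partition([(n, n * 3) for n in range(8)], 4)
right = hash_partition([(n, n * 2) for n in range(8)], 4)
result = join_copartitioned(left, right)
```

Each key from `left` finds its match in the bucket of `right` with the same index, which is the property a hash partitioner is meant to guarantee.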

Now let's use cache only as a reference:

timesCacheOnly = sqlContext.createDataFrame(
    run(init, "cache-only", True, False, False, ncores * 2))
timesCacheOnly.select("i", "init1", "init2", "init4", "time", "desc").show()


+-+-----+-----+-----+--------------+----------+
|i|init1|init2|init4|          time|      desc|
+-+-----+-----+-----+--------------+----------+
|0|   16|   16|   32|0:00:00.992865|cache-only|
|1|   32|   32|   64|0:00:01.766940|cache-only|
|2|   64|   64|  128|0:00:03.675924|cache-only|
|3|  128|  128|  256|0:00:06.477492|cache-only|
|4|  256|  256|  512|0:00:11.929242|cache-only|
|5|  512|  512| 1024|0:00:23.284508|cache-only|
+-+-----+-----+-----+--------------+----------+

As you can see, the number of partitions of the RDDs init2, init3 and init4 (columns init1, init2 and init4) for the cache-only version doubles with each iteration, and execution time is roughly proportional to the number of partitions.
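
The doubling can be reproduced with simple arithmetic. The sketch below is a plain-Python model, not Spark code: `join_output_partitions` and `simulate` are hypothetical names, and the rule "output partitions = left + right when numPartitions is not given" is an assumption chosen because it matches the table above.

```python
def join_output_partitions(left, right, num_partitions=None):
    """Model of the partition count of a join's output (an assumption)."""
    if num_partitions is not None:
        return num_partitions  # explicit numPartitions caps the output
    return left + right        # otherwise the inputs' partitions add up


def simulate(initial, iterations, num_partitions=None):
    """Partition count of the join output in each iteration of a self-join."""
    counts = []
    current = initial
    for _ in range(iterations):
        current = join_output_partitions(current, current, num_partitions)
        counts.append(current)
    return counts


uncapped = simulate(16, 6)                   # [32, 64, 128, 256, 512, 1024]
capped = simulate(16, 6, num_partitions=16)  # [16, 16, 16, 16, 16, 16]
```

The uncapped sequence matches the init4 column of the cache-only run, and the capped one matches the runs where numPartitions was passed to join.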

Finally we can check if we can improve performance with large number of partitions if we use hash partitioner:

timesCacheHashPart512 = sqlContext.createDataFrame(
    run(init, "cache + hashpart 512", True, True, True, 512))
timesCacheHashPart512.select(
    "i", "init1", "init2", "init4", "time", "desc").show()

+-+-----+-----+-----+--------------+--------------------+
|i|init1|init2|init4|          time|                desc|
+-+-----+-----+-----+--------------+--------------------+
|0|  512|  512|  512|0:00:14.492690|cache + hashpart 512|
|1|  512|  512|  512|0:00:20.215408|cache + hashpart 512|
|2|  512|  512|  512|0:00:20.408070|cache + hashpart 512|
|3|  512|  512|  512|0:00:20.390267|cache + hashpart 512|
|4|  512|  512|  512|0:00:20.362354|cache + hashpart 512|
|5|  512|  512|  512|0:00:19.878525|cache + hashpart 512|
+-+-----+-----+-----+--------------+--------------------+

The improvement is not so impressive, but if you have a small cluster and a lot of data it is still worth trying.

I guess the take-away message here is that partitioning matters. There are contexts where it is handled for you (mllib, sql), but if you use low-level operations it is your responsibility.

zero323 answered Nov 09 '22 07:11



The problem is (as zero323 pointed out in his thorough answer) that calling join without specifying the number of partitions results in a growing number of partitions, apparently without bound. There are at least two ways to prevent this when repeatedly calling join.

Method 1:

As zero323 pointed out, you can specify the number of partitions manually when you call join. For example:

rdd1.join(rdd2, numPartitions)

This will ensure that the number of partitions does not exceed numPartitions and, in particular, that the number of partitions will not continually grow.

Method 2:

When you create your SparkConf you can specify the default level of parallelism. If this value is set, then when you call functions like join without specifying numPartitions, the default parallelism will be used instead, effectively capping the number of partitions and preventing them from growing. You can set this parameter as

from pyspark import SparkConf, SparkContext
conf = SparkConf().set("spark.default.parallelism", numPartitions)
sc = SparkContext(conf=conf)
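
The way the two methods interact can be summarised in a small plain-Python model. This is only a sketch of the behaviour described above, not Spark's actual code: `resolve_join_partitions` is a hypothetical helper, and the fallback to the input partition count is an assumption based on the measurements in zero323's answer.

```python
def resolve_join_partitions(input_partitions, num_partitions=None,
                            default_parallelism=None):
    """Pick the output partition count of a join (illustrative model only)."""
    if num_partitions is not None:
        # Method 1: an explicit numPartitions argument wins.
        return num_partitions
    if default_parallelism is not None:
        # Method 2: spark.default.parallelism is used when it is set.
        return default_parallelism
    # Neither is given: the count follows the inputs and can keep growing.
    return input_partitions


unbounded = resolve_join_partitions(32)
with_default = resolve_join_partitions(32, default_parallelism=16)
explicit = resolve_join_partitions(32, num_partitions=8, default_parallelism=16)
```

Method 1 affects a single call, while Method 2 changes the default for the whole application.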
TravisJ answered Nov 09 '22 08:11
