Does Spark's distinct() function shuffle only the distinct tuples from each partition?

As I understand it, distinct() hash-partitions the RDD to identify the unique keys. But does it optimize the shuffle by moving only the distinct tuples from each partition?

Imagine an RDD with the following partitions:

  1. [1, 2, 2, 1, 4, 2, 2]
  2. [1, 3, 3, 5, 4, 5, 5, 5]

On a distinct() on this RDD, would all the duplicate keys (the 2s in partition 1 and the 5s in partition 2) get shuffled to their target partitions, or would only the distinct keys per partition get shuffled?
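
For reference, a minimal PySpark sketch of this setup (assuming a local SparkContext named sc; parallelize may split the data differently, so the partition contents are illustrative):

# Build the example RDD with two partitions
rdd = sc.parallelize([1, 2, 2, 1, 4, 2, 2, 1, 3, 3, 5, 4, 5, 5, 5], 2)
print(rdd.glom().collect())       # inspect per-partition contents
print(rdd.distinct().collect())   # deduplicated elements after the shuffle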

If all keys get shuffled, then an aggregate() with set() operations would reduce the shuffle:

def set_update(u, v):
    # Add one element to the per-partition accumulator set
    u.add(v)
    return u

# Merge per-partition sets with union; duplicates collapse before any data moves
rdd.aggregate(set(), set_update, lambda u1, u2: u1 | u2)
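
For the example RDD above, this returns the set {1, 2, 3, 4, 5}. Note though that aggregate is an action that returns a plain Python set on the driver, not an RDD, so it is only a drop-in replacement for distinct() when a local result is acceptable.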
asked Mar 13 '23 by zonked.zonda

1 Answer

distinct is implemented via reduceByKey on (element, None) pairs, so it shuffles only the unique values per partition. If the number of duplicates is low, it is still a fairly expensive operation though.
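
Conceptually (a sketch based on the description above, not the exact Spark source), distinct behaves like this:

# Roughly equivalent to rdd.distinct():
(rdd
    .map(lambda x: (x, None))      # pair each element with a dummy value
    .reduceByKey(lambda x, _: x)   # map-side combine dedupes within each partition first
    .map(lambda kv: kv[0]))        # drop the dummy value

The map-side combine is what limits the shuffle to one copy of each value per partition.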

There are situations where using set can be useful. In particular, if you call distinct on a PairwiseRDD, you may prefer to use aggregateByKey / combineByKey instead, to achieve both deduplication and partitioning by key at the same time. Consider the following code:

rdd1 = sc.parallelize([("foo", 1), ("foo", 1), ("bar", 1)])
rdd2 = sc.parallelize([("foo", "x"), ("bar", "y")])
rdd1.distinct().join(rdd2)

It has to shuffle rdd1 twice: once for distinct and once for join. Instead, you can use aggregateByKey:

def flatten(kvs):
    # Unpack (key, (set_of_values, right_value)) back into
    # individual (key, (value, right_value)) pairs
    (key, (left, right)) = kvs
    for v in left:
        yield (key, (v, right))

# Deduplicate values per key while shuffling rdd1 only once
aggregated = (rdd1
    .aggregateByKey(set(), set_update, lambda u1, u2: u1 | u2))

# Co-partition rdd2 so the join does not trigger another shuffle of aggregated
rdd2_partitioned = rdd2.partitionBy(aggregated.getNumPartitions())

(aggregated.join(rdd2_partitioned)
    .flatMap(flatten))
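
With the sample data above, this should produce the same pairs as the distinct-then-join version (ordering not guaranteed):

# Expected result for the sample data:
# [('foo', (1, 'x')), ('bar', (1, 'y'))]
print(aggregated.join(rdd2_partitioned).flatMap(flatten).collect())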

Note:

join logic is a little bit different in Scala than in Python (PySpark uses union followed by groupByKey; see Spark RDD groupByKey + join vs join performance for Python and Scala DAGs), hence we have to manually partition the second RDD before we call join.

answered Mar 16 '23 by zero323