As I understand it, distinct() hash-partitions the RDD to identify the unique keys. But does it optimize by moving only the distinct tuples per partition?
Imagine an RDD with the following partitions
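(The original partition listing is not reproduced here; for illustration, assume something along these lines, with duplicate 2s in the first partition and duplicate 5s in the second:)

# Hypothetical example data - not the original listing
rdd = sc.parallelize([1, 2, 2, 3, 4, 5, 5, 6], 2)
rdd.glom().collect()  # [[1, 2, 2, 3], [4, 5, 5, 6]]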
When distinct is called on this RDD, would all the duplicate keys (the 2s in partition 1 and the 5s in partition 2) get shuffled to their target partition, or would only the distinct keys per partition get shuffled to the target? If all keys get shuffled, then an aggregate() with set operations would reduce the shuffle.
def set_update(u, v):
    u.add(v)
    return u

rdd.aggregate(set(), set_update, lambda u1, u2: u1 | u2)
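With the assumed sample RDD above, this would evaluate to {1, 2, 3, 4, 5, 6}. Note that aggregate is an action, so the result here is a plain Python set on the driver rather than a distributed RDD.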
distinct is implemented via reduceByKey on (element, None) pairs, so it shuffles only the unique values per partition. If the number of duplicates is low, it is still quite an expensive operation, though.
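A simplified sketch of what distinct does under the hood (the function and variable names here are illustrative, but the map / reduceByKey / map pipeline mirrors the PySpark implementation):

def distinct_sketch(rdd, num_partitions=None):
    # Tag every element as a key with a dummy value
    keyed = rdd.map(lambda x: (x, None))
    # reduceByKey performs a map-side combine, so each partition ships
    # at most one copy of every value into the shuffle
    reduced = keyed.reduceByKey(lambda x, _: x, num_partitions)
    # Drop the dummy value again
    return reduced.map(lambda kv: kv[0])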
There are situations where using a set can be useful. In particular, if you call distinct on a PairwiseRDD, you may prefer to use aggregateByKey / combineByKey instead, to achieve both deduplication and partitioning by key at the same time. Consider the following code:
rdd1 = sc.parallelize([("foo", 1), ("foo", 1), ("bar", 1)])
rdd2 = sc.parallelize([("foo", "x"), ("bar", "y")])
rdd1.distinct().join(rdd2)
It has to shuffle rdd1 twice - once for distinct and once for join. Instead, you can use aggregateByKey / combineByKey:
def flatten(kvs):
    # Expand (key, (set_of_values, right)) back into (key, (value, right)) pairs
    (key, (left, right)) = kvs
    for v in left:
        yield (key, (v, right))

# Deduplicate values per key and hash-partition by key in a single shuffle
aggregated = (rdd1
    .aggregateByKey(set(), set_update, lambda u1, u2: u1 | u2))

# Partition the second RDD the same way so the join does not reshuffle it
rdd2_partitioned = rdd2.partitionBy(aggregated.getNumPartitions())

(aggregated.join(rdd2_partitioned)
    .flatMap(flatten))
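With the sample rdd1 and rdd2 above, the result should contain the same pairs as the distinct + join version (a quick sanity check, not part of the original answer):

result = aggregated.join(rdd2_partitioned).flatMap(flatten)
sorted(result.collect())
# [('bar', (1, 'y')), ('foo', (1, 'x'))]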
Note: the join logic is a little different in Scala than in Python (PySpark uses union followed by groupByKey; see Spark RDD groupByKey + join vs join performance for the Python and Scala DAGs), hence we have to manually partition the second RDD before we call join.
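If you want to verify that the two sides actually line up, both should now report the same partitioner and number of partitions (assuming the default portable_hash partitioning on both sides), so the join should not trigger another full shuffle of rdd2:

# Sanity check - both RDDs share the same hash partitioner
assert aggregated.partitioner == rdd2_partitioned.partitioner
assert aggregated.getNumPartitions() == rdd2_partitioned.getNumPartitions()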