I have seen a significant performance improvement in my PySpark code when I replaced distinct() on a Spark DataFrame with groupBy(), but I don't understand the reason behind it. The whole intention was to remove the row-level duplicates from the DataFrame.
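Roughly, the change looked like the following minimal sketch (toy data and made-up column names, not my actual code):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy DataFrame with duplicate rows (illustrative only).
df = spark.createDataFrame(
    [(1, "a"), (1, "a"), (2, "b")],
    ["id", "value"],
)

# Before: row-level deduplication with distinct().
deduped_before = df.distinct()

# After: the same deduplication expressed with groupBy() over all columns,
# dropping the extra "count" column right away.
deduped_after = df.groupBy(df.columns).count().drop("count")

deduped_before.explain(True)
deduped_after.explain(True)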
I tried Googling the implementation of groupBy() and distinct() in PySpark, but was unable to find it. Can somebody explain or point me in the right direction for the explanation?
I've recently focused on the difference between the GROUP BY and DISTINCT operations in Apache Spark SQL. It happens that... both can sometimes be the same!
To see this, run the following code and check the execution plans:
import sparkSession.implicits._  // needed so toDF works on the local sequence
(0 to 10).map(id => (s"id#${id}", s"login${id % 25}"))
  .toDF("id", "login").createTempView("users")
sparkSession.sql("SELECT login FROM users GROUP BY login").explain(true)
sparkSession.sql("SELECT DISTINCT(login) FROM users").explain(true)
Surprise, surprise! Both queries should produce the same physical plan, looking like this:
== Physical Plan ==
*(2) HashAggregate(keys=[login#8], functions=[], output=[login#8])
+- Exchange hashpartitioning(login#8, 200), ENSURE_REQUIREMENTS, [id=#33]
   +- *(1) HashAggregate(keys=[login#8], functions=[], output=[login#8])
      +- *(1) LocalTableScan [login#8]
Why? Because of the ReplaceDistinctWithAggregate rule that you should see in action in the logs:
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.ReplaceDistinctWithAggregate ===
!Distinct                      Aggregate [login#8], [login#8]
 +- LocalRelation [login#8]    +- LocalRelation [login#8]
(org.apache.spark.sql.catalyst.rules.PlanChangeLogger:65)
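If you don't spot that line, note that the plan change logger writes at TRACE level by default; since Spark 3.1 you can raise it so the applied rules show up with a default logging configuration (a sketch, assuming the users view from the PySpark snippet above):

# Log every rule that changes the plan at WARN level (visible with default log4j settings).
spark.conf.set("spark.sql.planChangeLog.level", "WARN")

spark.sql("SELECT DISTINCT(login) FROM users").explain(True)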
Update:
For more complex queries (e.g. with aggregations), there can be a difference:
sparkSession.sql("SELECT COUNT(login) FROM users GROUP BY login").explain(true)
sparkSession.sql("SELECT COUNT(DISTINCT(login)) FROM users").explain(true)
The GROUP BY version generates a plan with only one shuffle:
== Physical Plan ==
*(2) HashAggregate(keys=[login#8], functions=[count(login#8)], output=[count(login)#12L])
+- Exchange hashpartitioning(login#8, 200), ENSURE_REQUIREMENTS, [id=#16]
   +- *(1) HashAggregate(keys=[login#8], functions=[partial_count(login#8)], output=[login#8, count#15L])
      +- *(1) LocalTableScan [login#8]
Whereas the version with DISTINCT generates two shuffles: the first deduplicates the logins and the second counts them:
== Physical Plan ==
*(3) HashAggregate(keys=[], functions=[count(distinct login#8)], output=[count(DISTINCT login)#17L])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#48]
   +- *(2) HashAggregate(keys=[], functions=[partial_count(distinct login#8)], output=[count#21L])
      +- *(2) HashAggregate(keys=[login#8], functions=[], output=[login#8])
         +- Exchange hashpartitioning(login#8, 200), ENSURE_REQUIREMENTS, [id=#43]
            +- *(1) HashAggregate(keys=[login#8], functions=[], output=[login#8])
               +- *(1) LocalTableScan [login#8]
However, these queries are not semantically equivalent: the first returns a count per login group, whereas the second returns a single count of the distinct logins, and that is what explains the extra shuffle step.
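To make the semantic difference concrete, the two queries return differently shaped results (one row per login versus a single global row):

# One count per login ...
spark.sql("SELECT COUNT(login) FROM users GROUP BY login").show()

# ... versus a single row with the number of distinct logins.
spark.sql("SELECT COUNT(DISTINCT(login)) FROM users").show()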
It would be easier to answer the question with the before/after code. @pri, do you have it, so that we can analyze the plans executed by PySpark?