 

DataFrame / Dataset groupBy behaviour/optimization

Suppose we have a DataFrame df consisting of the following columns:

Name, Surname, Size, Width, Length, Weight

Now we want to perform a couple of operations; for example, we want to create a couple of DataFrames containing data about Size and Width.

import org.apache.spark.sql.functions.sum

val df1 = df.groupBy("surname").agg( sum("size") )
val df2 = df.groupBy("surname").agg( sum("width") )

As you can see, other columns, like Length, are not used anywhere. Is Spark smart enough to drop the redundant columns before the shuffle phase, or are they carried around? Will running:

val dfBasic = df.select("surname", "size", "width")

before grouping somehow affect the performance?

TheMP asked Oct 02 '15 08:10


1 Answer

Yes, it is "smart enough". groupBy performed on a DataFrame is not the same operation as groupBy performed on a plain RDD. In the scenario you've described there is no need to move raw data at all. Let's create a small example to illustrate that:

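import org.apache.spark.sql.functions.sum
import spark.implicits._ // toDF and $"..." syntax; typically pre-imported in spark-shell

val df = sc.parallelize(Seq(
   ("a", "foo", 1), ("a", "foo", 3), ("b", "bar", 5), ("b", "bar", 1)
)).toDF("x", "y", "z")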

df.groupBy("x").agg(sum($"z")).explain

// == Physical Plan ==
// *HashAggregate(keys=[x#148], functions=[sum(cast(z#150 as bigint))])
// +- Exchange hashpartitioning(x#148, 200)
//    +- *HashAggregate(keys=[x#148], functions=[partial_sum(cast(z#150 as bigint))])
//       +- *Project [_1#144 AS x#148, _3#146 AS z#150]
//          +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#144, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#145, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#146]
//             +- Scan ExternalRDDScan[obj#143]

As you can see, the first phase is a projection where only the required columns (x and z) are preserved. Next, the data is aggregated locally (partial_sum), and finally it is transferred (Exchange) and aggregated globally. You'll get slightly different output if you use Spark <= 1.4, but the general structure should be exactly the same.

Finally, a DAG visualization showing that the above description matches the actual job:

[Image: group by and agg DAG]
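To make those four steps concrete, here is a toy, single-process model of that plan in plain Scala (no Spark). Everything in it is an assumption for illustration: the object `TwoPhaseSum`, the record type `Rec`, the "partitions" (ordinary nested `Seq`s) and the `hashCode`-based routing are hypothetical stand-ins, not Spark's implementation. It shows why the unused column never travels: only `(key, partial sum)` pairs cross the simulated shuffle, at most one per key per partition.

```scala
// Toy, single-process model of: Project -> partial HashAggregate -> Exchange -> final HashAggregate.
// Hypothetical names throughout; this is NOT how Spark is implemented internally.
object TwoPhaseSum {
  // An input record; `y` plays the role of a column the query never uses.
  final case class Rec(x: String, y: String, z: Int)

  // Project: keep only the grouping key and the value being summed.
  def project(part: Seq[Rec]): Seq[(String, Long)] =
    part.map(r => (r.x, r.z.toLong))

  // Partial aggregate: one (key, partial sum) pair per distinct key in this partition.
  def partialSum(part: Seq[(String, Long)]): Map[String, Long] =
    part.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2).sum }

  // Exchange: route each partial result to a reducer by key hash
  // (a simple hashCode router, not Spark's partitioner).
  def exchange(partials: Seq[Map[String, Long]], numReducers: Int): Seq[Seq[(String, Long)]] = {
    val records = partials.flatMap(_.toSeq)
    (0 until numReducers).map { i =>
      records.filter { case (k, _) => Math.floorMod(k.hashCode, numReducers) == i }
    }
  }

  // Final aggregate: merge the partial sums that arrived at one reducer.
  def finalSum(in: Seq[(String, Long)]): Map[String, Long] =
    in.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2).sum }

  // Returns the global sums and the number of records that crossed the simulated shuffle.
  def run(partitions: Seq[Seq[Rec]], numReducers: Int): (Map[String, Long], Int) = {
    val partials = partitions.map(p => partialSum(project(p)))
    val shuffled = exchange(partials, numReducers)
    // Each key lands on exactly one reducer, so merging the per-reducer maps cannot collide.
    val totals = shuffled.map(finalSum).foldLeft(Map.empty[String, Long])(_ ++ _)
    (totals, shuffled.map(_.size).sum)
  }
}
```

With the four example rows placed so that each partition holds a single key, `run` returns sums of 4 for "a" and 6 for "b" while only 2 records are shuffled; if both keys appear in both partitions, the totals are the same but 4 partial records are shuffled.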

Similarly, Dataset.groupByKey followed by reduceGroups contains both a map-side (ObjectHashAggregate with partial_reduceaggregator) and a reduce-side (ObjectHashAggregate with reduceaggregator) reduction:

case class Foo(x: String, y: String, z: Int)

val ds = df.as[Foo]
ds.groupByKey(_.x).reduceGroups((x, y) => x.copy(z = x.z + y.z)).explain

// == Physical Plan ==
// ObjectHashAggregate(keys=[value#126], functions=[reduceaggregator(org.apache.spark.sql.expressions.ReduceAggregator@54d90261, Some(newInstance(class $line40.$read$$iw$$iw$Foo)), Some(class $line40.$read$$iw$$iw$Foo), Some(StructType(StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false))), input[0, scala.Tuple2, true]._1 AS value#128, if ((isnull(input[0, scala.Tuple2, true]._2) || None.equals)) null else named_struct(x, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).x, true, false) AS x#25, y, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).y, true, false) AS y#26, z, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).z AS z#27) AS _2#129, newInstance(class scala.Tuple2), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).x, true, false) AS x#25, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).y, true, false) AS y#26, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).z AS z#27, StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false), true, 0, 0)])
// +- Exchange hashpartitioning(value#126, 200)
//    +- ObjectHashAggregate(keys=[value#126], functions=[partial_reduceaggregator(org.apache.spark.sql.expressions.ReduceAggregator@54d90261, Some(newInstance(class $line40.$read$$iw$$iw$Foo)), Some(class $line40.$read$$iw$$iw$Foo), Some(StructType(StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false))), input[0, scala.Tuple2, true]._1 AS value#128, if ((isnull(input[0, scala.Tuple2, true]._2) || None.equals)) null else named_struct(x, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).x, true, false) AS x#25, y, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).y, true, false) AS y#26, z, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).z AS z#27) AS _2#129, newInstance(class scala.Tuple2), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).x, true, false) AS x#25, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).y, true, false) AS y#26, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).z AS z#27, StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false), true, 0, 0)])
//       +- AppendColumns <function1>, newInstance(class $line40.$read$$iw$$iw$Foo), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#126]
//          +- *Project [_1#4 AS x#8, _2#5 AS y#9, _3#6 AS z#10]
//             +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#4, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#5, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#6]
//                +- Scan ExternalRDDScan[obj#3]

[Image: groupByKey + reduceGroups]
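The same map-side/reduce-side split can be sketched for reduceGroups. Again this is a hypothetical plain-Scala model (`ReduceGroupsModel` and its methods are illustrative names, not Spark API): the user's function is applied once within each partition and once more to combine the per-partition results. That is only equivalent to reducing each whole group at once if the function is associative and commutative, which is what the reduceGroups documentation requires of it.

```scala
// Toy model of reduceGroups' two ObjectHashAggregate steps; hypothetical names, not Spark API.
object ReduceGroupsModel {
  final case class Foo(x: String, y: String, z: Int)

  // Map side (cf. partial_reduceaggregator): reduce each key within one partition.
  def mapSide[K, V](part: Seq[V])(key: V => K)(f: (V, V) => V): Map[K, V] =
    part.groupBy(key).map { case (k, vs) => k -> vs.reduce(f) }

  // Reduce side (cf. reduceaggregator): combine the per-partition results for each key.
  def reduceSide[K, V](partials: Seq[Map[K, V]])(f: (V, V) => V): Map[K, V] =
    partials.flatMap(_.toSeq).groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2).reduce(f) }

  // Map-side reduce in every partition, then a reduce-side merge of the partial results.
  def reduceGroups[K, V](partitions: Seq[Seq[V]])(key: V => K)(f: (V, V) => V): Map[K, V] =
    reduceSide(partitions.map(p => mapSide(p)(key)(f)))(f)
}
```

Called with the question-style reducer `(l, r) => l.copy(z = l.z + r.z)`, each partition ships at most one `Foo` per key, and the reduce side merges those into the final per-key values.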

However, other methods of KeyValueGroupedDataset may behave more like RDD.groupByKey. For example, mapGroups (or flatMapGroups) doesn't use partial aggregation: every record is shuffled (Exchange) and sorted by key (Sort) before the function is applied.

ds.groupByKey(_.x)
  .mapGroups((_, iter) => iter.reduce((x, y) => x.copy(z = x.z + y.z)))
  .explain

//== Physical Plan ==
//*SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).x, true, false) AS x#37, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).y, true, false) AS y#38, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).z AS z#39]
//+- MapGroups <function2>, value#32.toString, newInstance(class $line15.$read$$iw$$iw$Foo), [value#32], [x#8, y#9, z#10], obj#36: $line15.$read$$iw$$iw$Foo
//   +- *Sort [value#32 ASC NULLS FIRST], false, 0
//      +- Exchange hashpartitioning(value#32, 200)
//         +- AppendColumns <function1>, newInstance(class $line15.$read$$iw$$iw$Foo), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#32]
//            +- *Project [_1#4 AS x#8, _2#5 AS y#9, _3#6 AS z#10]
//               +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#4, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#5, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#6]
//                  +- Scan ExternalRDDScan[obj#3]

[Image: groupByKey + mapGroups]
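For contrast, here is a toy model of the mapGroups plan, again in plain Scala with hypothetical names and a simple `hashCode`-based router standing in for Spark's partitioner. Every input record is routed to a reducer, groups are processed in key order, and only then is each group handed to the user function as an iterator. The returned count is the number of records that crossed the simulated shuffle, which here always equals the number of input rows.

```scala
// Toy model of the mapGroups plan: Exchange -> Sort -> MapGroups, with no partial aggregation.
// Hypothetical names; a plain hashCode router stands in for Spark's partitioner.
object MapGroupsModel {
  final case class Foo(x: String, y: String, z: Int)

  // Returns the user function's outputs and the number of records that crossed the shuffle.
  def mapGroups[K: Ordering, V, U](partitions: Seq[Seq[V]], numReducers: Int)
                                  (key: V => K)
                                  (f: (K, Iterator[V]) => U): (Seq[U], Int) = {
    val all = partitions.flatten
    // Exchange: every raw record is sent to a reducer; nothing is combined beforehand.
    val reducers = (0 until numReducers).map { i =>
      all.filter(v => Math.floorMod(key(v).hashCode, numReducers) == i)
    }
    // Sort + MapGroups: on each reducer, process groups in key order and pass
    // each group's records to f as an iterator.
    val out = reducers.flatMap { recs =>
      recs.groupBy(key).toSeq.sortBy(_._1).map { case (k, vs) => f(k, vs.iterator) }
    }
    (out, reducers.map(_.size).sum)
  }
}
```

With the same four rows and the same reducer function, the results match the reduceGroups version, but all four rows are shuffled no matter how they are partitioned, whereas a partial aggregation can ship as little as one record per key per partition.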

zero323 answered Nov 12 '22 05:11