
How to flatten columns of type array of structs (as returned by Spark ML API)?

Maybe it's just because I'm relatively new to the API, but I feel like Spark ML methods often return DFs that are unnecessarily difficult to work with.

This time, it's the ALS model that's tripping me up. Specifically, the recommendForAllUsers method. Let's reconstruct the type of DF it would return:

scala> import org.apache.spark.sql.types._

scala> val arrayType = ArrayType(new StructType().add("itemId", IntegerType).add("rating", FloatType))

scala> val recs = Seq((1, Array((1, .7), (2, .5))), (2, Array((0, .9), (4, .1)))).
  toDF("userId", "recommendations").
  select($"userId", $"recommendations".cast(arrayType))

scala> recs.show()
+------+------------------+
|userId|   recommendations|
+------+------------------+
|     1|[[1,0.7], [2,0.5]]|
|     2|[[0,0.9], [4,0.1]]|
+------+------------------+
scala> recs.printSchema
root
 |-- userId: integer (nullable = false)
 |-- recommendations: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- itemId: integer (nullable = true)
 |    |    |-- rating: float (nullable = true)

Now, I only care about the itemId in the recommendations column. After all, the method is recommendForAllUsers, not recommendAndScoreForAllUsers (ok, ok, I'll stop being sassy...).

How do I do this??

I thought I had it when I created a UDF:

scala> val itemIds = udf((arr: Array[(Int, Float)]) => arr.map(_._1))

but that produces an error:

scala> recs.withColumn("items", itemIds($"recommendations"))
org.apache.spark.sql.AnalysisException: cannot resolve 'UDF(recommendations)' due to data type mismatch: argument 1 requires array<struct<_1:int,_2:float>> type, however, '`recommendations`' is of array<struct<itemId:int,rating:float>> type.;;
'Project [userId#87, recommendations#92, UDF(recommendations#92) AS items#238]
+- Project [userId#87, cast(recommendations#88 as array<struct<itemId:int,rating:float>>) AS recommendations#92]
   +- Project [_1#84 AS userId#87, _2#85 AS recommendations#88]
      +- LocalRelation [_1#84, _2#85]

Any ideas? Thanks!

asked Oct 13 '17 by Nick Resnick



2 Answers

wow, my coworker came up with an extremely elegant solution:

scala> recs.select($"userId", $"recommendations.itemId").show
+------+------+
|userId|itemId|
+------+------+
|     1|[1, 2]|
|     2|[0, 4]|
+------+------+

So maybe the Spark ML API isn't that difficult after all :)
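
For anyone wiring this into the real API, here's a minimal sketch of the same projection applied to an actual ALSModel. The ratings DataFrame and its column names are assumptions on my part; recommendForAllUsers(n) does return a user id column plus an array-of-struct recommendations column shaped like the one above:

import org.apache.spark.ml.recommendation.ALS

// hypothetical ratings DataFrame with columns userId, itemId, rating
val als = new ALS().
  setUserCol("userId").
  setItemCol("itemId").
  setRatingCol("rating")
val model = als.fit(ratings)

// top 5 recommendations per user, keeping only the item ids
val topItems = model.recommendForAllUsers(5).
  select($"userId", $"recommendations.itemId")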

answered by Nick Resnick


With an array as the type of a column, e.g. recommendations, you'd be quite productive using the explode function (or the more advanced flatMap operator; a typed sketch of that follows below).

explode(e: Column): Column Creates a new row for each element in the given array or map column.

That gives you bare structs to work with.

import org.apache.spark.sql.types._
val structType = new StructType().
  add($"itemId".int).
  add($"rating".float)
val arrayType = ArrayType(structType)
val recs = Seq((1, Array((1, .7), (2, .5))), (2, Array((0, .9), (4, .1)))).
  toDF("userId", "recommendations").
  select($"userId", $"recommendations" cast arrayType)

val exploded = recs.withColumn("recs", explode($"recommendations"))
scala> exploded.show
+------+------------------+-------+
|userId|   recommendations|   recs|
+------+------------------+-------+
|     1|[[1,0.7], [2,0.5]]|[1,0.7]|
|     1|[[1,0.7], [2,0.5]]|[2,0.5]|
|     2|[[0,0.9], [4,0.1]]|[0,0.9]|
|     2|[[0,0.9], [4,0.1]]|[4,0.1]|
+------+------------------+-------+

Structs work nicely with the select operator and * (star), which flattens them into one column per struct field. Here the exploded struct column is recs, so you can do select($"recs.*").

scala> exploded.select("userId", "recs.*").show
+------+------+------+
|userId|itemId|rating|
+------+------+------+
|     1|     1|   0.7|
|     1|     2|   0.5|
|     2|     0|   0.9|
|     2|     4|   0.1|
+------+------+------+

I think that could do what you're after.
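
If you'd rather take the typed flatMap route mentioned at the top, here's a minimal sketch; the Rec and UserRecs case classes are mine, introduced only so the encoder's schema matches recs:

case class Rec(itemId: Int, rating: Float)
case class UserRecs(userId: Int, recommendations: Seq[Rec])

// one (userId, itemId) row per recommended item
val itemsOnly = recs.as[UserRecs].
  flatMap(ur => ur.recommendations.map(r => (ur.userId, r.itemId))).
  toDF("userId", "itemId")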


p.s. Stay away from UDFs as long as possible, since they "trigger" a conversion of rows from the internal format (InternalRow) to JVM objects, which can lead to excessive GC.
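
For the record, the UDF in the question failed because Spark passes struct elements to a Scala UDF as Row values, not as tuples. A version along these lines should resolve, though the built-in operators above remain the better choice:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

// struct elements arrive as Row, so read itemId by field position
val itemIds = udf((recs: Seq[Row]) => recs.map(_.getInt(0)))
recs.withColumn("items", itemIds($"recommendations"))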

answered by Jacek Laskowski