I have a simple dataframe like this:
rdd = sc.parallelize(
    [
        (0, "A", 223, "201603", "PORT"),
        (0, "A", 22, "201602", "PORT"),
        (0, "A", 422, "201601", "DOCK"),
        (1, "B", 3213, "201602", "DOCK"),
        (1, "B", 3213, "201601", "PORT"),
        (2, "C", 2321, "201601", "DOCK"),
    ]
)
df_data = sqlContext.createDataFrame(rdd, ["id", "type", "cost", "date", "ship"])
df_data.show()

+---+----+----+------+----+
| id|type|cost|  date|ship|
+---+----+----+------+----+
|  0|   A| 223|201603|PORT|
|  0|   A|  22|201602|PORT|
|  0|   A| 422|201601|DOCK|
|  1|   B|3213|201602|DOCK|
|  1|   B|3213|201601|PORT|
|  2|   C|2321|201601|DOCK|
+---+----+----+------+----+
and I need to pivot it by date:
df_data.groupby(df_data.id, df_data.type).pivot("date").avg("cost").show()

+---+----+------+------+------+
| id|type|201601|201602|201603|
+---+----+------+------+------+
|  2|   C|2321.0|  null|  null|
|  0|   A| 422.0|  22.0| 223.0|
|  1|   B|3213.0|3213.0|  null|
+---+----+------+------+------+
Everything works as expected. But now I need to pivot it and get a non-numeric column:
df_data.groupby(df_data.id, df_data.type).pivot("date").avg("ship").show()
and of course I would get an exception:
AnalysisException: u'"ship" is not a numeric column. Aggregation function can only be applied on a numeric column.;'
I would like to generate something along the lines of:
+---+----+------+------+------+
| id|type|201601|201602|201603|
+---+----+------+------+------+
|  2|   C|  DOCK|  null|  null|
|  0|   A|  DOCK|  PORT|  DOCK|
|  1|   B|  DOCK|  PORT|  null|
+---+----+------+------+------+
Is that possible with pivot?
Assuming that (id, type, date) combinations are unique and your only goal is pivoting and not aggregation, you can use first (or any other function not restricted to numeric values):
from pyspark.sql.functions import first

(df_data
    .groupby(df_data.id, df_data.type)
    .pivot("date")
    .agg(first("ship"))
    .show())

## +---+----+------+------+------+
## | id|type|201601|201602|201603|
## +---+----+------+------+------+
## |  2|   C|  DOCK|  null|  null|
## |  0|   A|  DOCK|  PORT|  PORT|
## |  1|   B|  PORT|  DOCK|  null|
## +---+----+------+------+------+
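For what it's worth, first is just one choice; any aggregate that accepts non-numeric columns works here. A minimal sketch using last instead (with unique (id, type, date) combinations it yields the same result as the snippet above):

from pyspark.sql.functions import last

# last() also accepts string columns; with a single ship value per
# (id, type, date) the result matches the first() version above.
(df_data
    .groupby("id", "type")
    .pivot("date")
    .agg(last("ship"))
    .show())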
If these assumptions are not correct, you'll have to pre-aggregate your data. For example, to take the most common ship value:
from pyspark.sql.functions import max, struct

(df_data
    .groupby("id", "type", "date", "ship")
    .count()
    .groupby("id", "type")
    .pivot("date")
    .agg(max(struct("count", "ship")))
    .show())

## +---+----+--------+--------+--------+
## | id|type|  201601|  201602|  201603|
## +---+----+--------+--------+--------+
## |  2|   C|[1,DOCK]|    null|    null|
## |  0|   A|[1,DOCK]|[1,PORT]|[1,PORT]|
## |  1|   B|[1,PORT]|[1,DOCK]|    null|
## +---+----+--------+--------+--------+
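If you only need the ship value in the final output (without the count), one follow-up option is to pull the ship field back out of each struct column. A minimal sketch, assuming the pivoted columns are exactly the date values from the example (accessing a field of a null struct simply returns null):

from pyspark.sql.functions import col, max, struct

pivoted = (df_data
    .groupby("id", "type", "date", "ship")
    .count()
    .groupby("id", "type")
    .pivot("date")
    .agg(max(struct("count", "ship"))))

# Each pivoted column is a struct<count, ship>; keep only the ship field.
date_cols = [c for c in pivoted.columns if c not in ("id", "type")]
pivoted.select("id", "type", *[col(c)["ship"].alias(c) for c in date_cols]).show()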
In case someone is looking for a SQL-style approach:
rdd = spark.sparkContext.parallelize(
    [
        (0, "A", 223, "201603", "PORT"),
        (0, "A", 22, "201602", "PORT"),
        (0, "A", 422, "201601", "DOCK"),
        (1, "B", 3213, "201602", "DOCK"),
        (1, "B", 3213, "201601", "PORT"),
        (2, "C", 2321, "201601", "DOCK"),
    ]
)
df_data = spark.createDataFrame(rdd, ["id", "type", "cost", "date", "ship"])
df_data.createOrReplaceTempView("df")
df_data.show()

dt_vals = spark.sql("select collect_set(date) from df").collect()[0][0]
# ['201601', '201602', '201603']

dt_vals_colstr = ",".join(["'" + c + "'" for c in sorted(dt_vals)])
# "'201601','201602','201603'"
Part-1 (note the f string prefix, which interpolates dt_vals_colstr into the query)
spark.sql(f""" select * from (select id , type, date, ship from df) pivot ( first(ship) for date in ({dt_vals_colstr}) ) """).show(100,truncate=False) +---+----+------+------+------+ |id |type|201601|201602|201603| +---+----+------+------+------+ |1 |B |PORT |DOCK |null | |2 |C |DOCK |null |null | |0 |A |DOCK |PORT |PORT | +---+----+------+------+------+
Part-2
spark.sql(f""" select * from (select id , type, date, ship from df) pivot ( case when count(*)=0 then null else struct(count(*),first(ship)) end for date in ({dt_vals_colstr}) ) """).show(100,truncate=False) +---+----+---------+---------+---------+ |id |type|201601 |201602 |201603 | +---+----+---------+---------+---------+ |1 |B |[1, PORT]|[1, DOCK]|null | |2 |C |[1, DOCK]|null |null | |0 |A |[1, DOCK]|[1, PORT]|[1, PORT]| +---+----+---------+---------+---------+