
Pivot string column on PySpark DataFrame

I have a simple dataframe like this:

rdd = sc.parallelize(
    [
        (0, "A", 223, "201603", "PORT"),
        (0, "A", 22, "201602", "PORT"),
        (0, "A", 422, "201601", "DOCK"),
        (1, "B", 3213, "201602", "DOCK"),
        (1, "B", 3213, "201601", "PORT"),
        (2, "C", 2321, "201601", "DOCK")
    ]
)
df_data = sqlContext.createDataFrame(rdd, ["id", "type", "cost", "date", "ship"])

df_data.show()

+---+----+----+------+----+
| id|type|cost|  date|ship|
+---+----+----+------+----+
|  0|   A| 223|201603|PORT|
|  0|   A|  22|201602|PORT|
|  0|   A| 422|201601|DOCK|
|  1|   B|3213|201602|DOCK|
|  1|   B|3213|201601|PORT|
|  2|   C|2321|201601|DOCK|
+---+----+----+------+----+

and I need to pivot it by date:

df_data.groupby(df_data.id, df_data.type).pivot("date").avg("cost").show()

+---+----+------+------+------+
| id|type|201601|201602|201603|
+---+----+------+------+------+
|  2|   C|2321.0|  null|  null|
|  0|   A| 422.0|  22.0| 223.0|
|  1|   B|3213.0|3213.0|  null|
+---+----+------+------+------+

Everything works as expected. But now I need to pivot it and get a non-numeric column:

df_data.groupby(df_data.id, df_data.type).pivot("date").avg("ship").show() 

and of course I would get an exception:

AnalysisException: u'"ship" is not a numeric column. Aggregation function can only be applied on a numeric column.;' 

I would like to generate something along the lines of:

+---+----+------+------+------+
| id|type|201601|201602|201603|
+---+----+------+------+------+
|  2|   C|  DOCK|  null|  null|
|  0|   A|  DOCK|  PORT|  DOCK|
|  1|   B|  DOCK|  PORT|  null|
+---+----+------+------+------+

Is that possible with pivot?

asked May 27 '16 by Ivan


People also ask

How do I pivot a DataFrame in PySpark?

The PySpark pivot() function is used to rotate/transpose data from one column into multiple DataFrame columns, and unpivot() reverses the operation. pivot() is an aggregation in which the values of one of the grouping columns are transposed into individual columns with distinct data. A minimal sketch against the df_data frame from the question is shown below; reversing the pivot with a stack() expression is an assumption about how you might unpivot on Spark versions that lack a built-in unpivot().
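from pyspark.sql.functions import avg

# pivot: one row per (id, type), one column per distinct date value
pivoted = df_data.groupby("id", "type").pivot("date").agg(avg("cost"))

# unpivot back to long form with a stack() SQL expression
# (backticks are needed because the pivoted column names are numeric)
unpivoted = pivoted.selectExpr(
    "id", "type",
    "stack(3, '201601', `201601`, '201602', `201602`, '201603', `201603`) as (date, cost)"
)
unpivoted.show()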

How do you rename a pivot column in PySpark?

PySpark has a withColumnRenamed() function on DataFrame to change a column name. This is the most straightforward approach; the function takes two parameters: the first is the existing column name and the second is the new column name you want. A small sketch of renaming pivoted date columns follows; the target names (jan_2016, etc.) are purely illustrative.
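from pyspark.sql.functions import first

pivoted = df_data.groupby("id", "type").pivot("date").agg(first("ship"))

# withColumnRenamed(existing_name, new_name); chain one call per column
renamed = (pivoted
    .withColumnRenamed("201601", "jan_2016")
    .withColumnRenamed("201602", "feb_2016")
    .withColumnRenamed("201603", "mar_2016"))
renamed.show()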

What does .collect do in PySpark?

collect() is an operation on an RDD or DataFrame that is used to retrieve its data. It is useful for retrieving all the elements of every row from each partition and bringing them back to the driver node/program. A small sketch follows; collect() is safe on the tiny pivoted frame here, but because it pulls every row to the driver it should be used cautiously on large DataFrames.
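from pyspark.sql.functions import first

rows = (df_data
    .groupby("id", "type")
    .pivot("date")
    .agg(first("ship"))
    .collect())

# rows is now a plain Python list of Row objects on the driver
for row in rows:
    print(row.id, row.type, row["201601"])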


2 Answers

Assuming that (id, type, date) combinations are unique and your only goal is pivoting, not aggregation, you can use first (or any other function not restricted to numeric values):

from pyspark.sql.functions import first

(df_data
    .groupby(df_data.id, df_data.type)
    .pivot("date")
    .agg(first("ship"))
    .show())

## +---+----+------+------+------+
## | id|type|201601|201602|201603|
## +---+----+------+------+------+
## |  2|   C|  DOCK|  null|  null|
## |  0|   A|  DOCK|  PORT|  PORT|
## |  1|   B|  PORT|  DOCK|  null|
## +---+----+------+------+------+

If these assumptions are not correct, you'll have to pre-aggregate your data, for example by keeping the most common ship value:

from pyspark.sql.functions import max, struct

(df_data
    .groupby("id", "type", "date", "ship")
    .count()
    .groupby("id", "type")
    .pivot("date")
    .agg(max(struct("count", "ship")))
    .show())

## +---+----+--------+--------+--------+
## | id|type|  201601|  201602|  201603|
## +---+----+--------+--------+--------+
## |  2|   C|[1,DOCK]|    null|    null|
## |  0|   A|[1,DOCK]|[1,PORT]|[1,PORT]|
## |  1|   B|[1,PORT]|[1,DOCK]|    null|
## +---+----+--------+--------+--------+
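If you only want the ship value (without the count) in the final cells, one possible follow-up is to pull the field back out of the struct after the aggregation; a minimal sketch, assuming that is the display you are after:

from pyspark.sql.functions import max, struct

(df_data
    .groupby("id", "type", "date", "ship")
    .count()
    .groupby("id", "type")
    .pivot("date")
    # max over (count, ship) picks the most common ship; keep only the ship field
    .agg(max(struct("count", "ship")).getField("ship"))
    .show())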
answered Oct 02 '22 by zero323


In case someone is looking for a SQL-style approach:

rdd = spark.sparkContext.parallelize(
    [
        (0, "A", 223, "201603", "PORT"),
        (0, "A", 22, "201602", "PORT"),
        (0, "A", 422, "201601", "DOCK"),
        (1, "B", 3213, "201602", "DOCK"),
        (1, "B", 3213, "201601", "PORT"),
        (2, "C", 2321, "201601", "DOCK")
    ]
)
df_data = spark.createDataFrame(rdd, ["id", "type", "cost", "date", "ship"])
df_data.createOrReplaceTempView("df")
df_data.show()

dt_vals = spark.sql("select collect_set(date) from df").collect()[0][0]
# ['201601', '201602', '201603']

dt_vals_colstr = ",".join(["'" + c + "'" for c in sorted(dt_vals)])
# "'201601','201602','201603'"

Part 1 (note the f-string format specifier):

spark.sql(f"""
select * from
(select id, type, date, ship from df)
pivot (
    first(ship) for date in ({dt_vals_colstr})
)
""").show(100, truncate=False)

+---+----+------+------+------+
|id |type|201601|201602|201603|
+---+----+------+------+------+
|1  |B   |PORT  |DOCK  |null  |
|2  |C   |DOCK  |null  |null  |
|0  |A   |DOCK  |PORT  |PORT  |
+---+----+------+------+------+

Part 2:

spark.sql(f"""
select * from
(select id, type, date, ship from df)
pivot (
    case when count(*)=0 then null
         else struct(count(*), first(ship)) end
    for date in ({dt_vals_colstr})
)
""").show(100, truncate=False)

+---+----+---------+---------+---------+
|id |type|201601   |201602   |201603   |
+---+----+---------+---------+---------+
|1  |B   |[1, PORT]|[1, DOCK]|null     |
|2  |C   |[1, DOCK]|null     |null     |
|0  |A   |[1, DOCK]|[1, PORT]|[1, PORT]|
+---+----+---------+---------+---------+
answered Oct 02 '22 by stack0114106