I have some data in the following format (either RDD or Spark DataFrame):
from pyspark.sql import SQLContext
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

sqlContext = SQLContext(sc)

rdd = sc.parallelize([('X01', 41, 'US', 3),
                      ('X01', 41, 'UK', 1),
                      ('X01', 41, 'CA', 2),
                      ('X02', 72, 'US', 4),
                      ('X02', 72, 'UK', 6),
                      ('X02', 72, 'CA', 7),
                      ('X02', 72, 'XX', 8)])

# convert to a Spark DataFrame
schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])
df = sqlContext.createDataFrame(rdd, schema)
What I would like to do is 'reshape' the data, converting certain rows in Country (specifically US, UK, and CA) into columns:
ID     Age  US  UK  CA
'X01'   41   3   1   2
'X02'   72   4   6   7
Essentially, I need something along the lines of Python's pandas pivot workflow:
categories = ['US', 'UK', 'CA']
new_df = df[df['Country'].isin(categories)].pivot(index='ID',
                                                  columns='Country',
                                                  values='Score')
My dataset is rather large, so I can't really collect() and ingest the data into memory to do the reshaping in Python itself. Is there a way to convert Python's .pivot() into an invokable function while mapping either an RDD or a Spark DataFrame? Any help would be appreciated!
For simple operations like grouping data, an RDD is slower than both DataFrames and Datasets. DataFrames provide an easy API for aggregation operations and perform aggregation faster than both RDDs and Datasets; a Dataset is faster than an RDD but a bit slower than a DataFrame.

RDD – an RDD is a distributed collection of data elements spread across many machines in the cluster; RDDs are a set of Java or Scala objects representing data. DataFrame – a DataFrame is a distributed collection of data organized into named columns; it is conceptually equivalent to a table in a relational database.
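To illustrate that difference in ergonomics, here is a minimal sketch of the same per-(ID, Age) aggregation in both APIs, reusing the rdd and df defined above (the variable names rdd_sums and df_sums are illustrative, not a prescribed API):

# RDD API: key by (ID, Age) and reduce the scores by hand.
rdd_sums = (rdd
    .map(lambda t: ((t[0], t[1]), t[3]))
    .reduceByKey(lambda a, b: a + b))

# DataFrame API: the same aggregation as a one-liner.
df_sums = df.groupBy('ID', 'Age').sum('Score')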
Pivot Spark DataFrame: Spark SQL provides a pivot() function to rotate the data from one column into multiple columns (transposing rows to columns). It is an aggregation in which the distinct values of one of the grouping columns are transposed into individual columns.
Since Spark 1.6 you can use the pivot function on GroupedData and provide an aggregate expression.
pivoted = (df
    .groupBy("ID", "Age")
    .pivot(
        "Country",
        ['US', 'UK', 'CA'])  # Optional list of levels
    .sum("Score"))  # alternatively you can use .agg(expr)

pivoted.show()

## +---+---+---+---+---+
## | ID|Age| US| UK| CA|
## +---+---+---+---+---+
## |X01| 41|  3|  1|  2|
## |X02| 72|  4|  6|  7|
## +---+---+---+---+---+
Levels can be omitted, but if provided they can both boost performance and serve as an internal filter.
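For illustration, here is a minimal sketch (reusing df from the question, and assuming the standard pyspark.sql.functions import) of what omitting the levels looks like; Spark then has to compute the distinct values of Country itself, so the XX column appears as well:

from pyspark.sql import functions as F

# Sketch: no list of levels, so Spark first computes the distinct
# values of "Country" (an extra pass over the data); 'XX' now
# appears as a column too.
pivoted_all = (df
    .groupBy("ID", "Age")
    .pivot("Country")
    .agg(F.sum("Score")))  # the .agg(expr) form mentioned above
pivoted_all.show()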
This method is still relatively slow but certainly beats passing data manually between the JVM and Python.
First up, this is probably not a good idea, because you are not getting any extra information, but you are binding yourself to a fixed schema (i.e. you must know in advance how many countries you are expecting, and of course an additional country means a change in code).
Having said that, this is a SQL problem, as shown below. But in case you feel it is not too "software-like" (seriously, I have heard this!!), then you can refer to the first solution.
Solution 1:
def reshape(t):
    # One-hot encode the score against the broadcast country list:
    # ('X01', 41, 'CA', 2) -> (('X01', 41), (2, 0, 0, 0))
    out = [t[0], t[1]]
    for v in brc.value:
        if t[2] == v:
            out.append(t[3])
        else:
            out.append(0)
    return (out[0], out[1]), (out[2], out[3], out[4], out[5])

def cntryFilter(t):
    # Keep only the rows whose country is in the broadcast list
    return t[2] in brc.value

def addtup(t1, t2):
    # Element-wise sum of two equal-length tuples
    j = ()
    for k, v in enumerate(t1):
        j = j + (t1[k] + t2[k],)
    return j

def seq(tIntrm, tNext):
    return addtup(tIntrm, tNext)

def comb(tP, tF):
    return addtup(tP, tF)

countries = ['CA', 'UK', 'US', 'XX']
brc = sc.broadcast(countries)
reshaped = calls.filter(cntryFilter).map(reshape)
pivot = reshaped.aggregateByKey((0, 0, 0, 0), seq, comb, 1)
for i in pivot.collect():
    print(i)
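To see what aggregateByKey is doing here: reshape turns each filtered row into a one-hot tuple keyed by (userid, age), and addtup then sums those tuples element-wise per key. A quick plain-Python illustration (no Spark needed):

# ('X01', 41, 'CA', 2) -> (('X01', 41), (2, 0, 0, 0)) given
# countries = ['CA', 'UK', 'US', 'XX']; aggregateByKey folds these
# one-hot tuples into the (0, 0, 0, 0) zero value per key:
print(addtup((2, 0, 0, 0), (0, 1, 0, 0)))  # (2, 1, 0, 0)
print(addtup((2, 1, 0, 0), (0, 0, 3, 0)))  # (2, 1, 3, 0)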
Now, Solution 2: of course better, as SQL is the right tool for this
from pyspark.sql import Row

callRow = calls.map(lambda t: Row(userid=t[0], age=int(t[1]),
                                  country=t[2], nbrCalls=t[3]))
callsDF = ssc.createDataFrame(callRow)  # ssc is a SQLContext
callsDF.printSchema()
callsDF.registerTempTable("calls")
res = ssc.sql("""select userid, age, max(ca), max(uk), max(us), max(xx)
                 from (select userid, age,
                              case when country='CA' then nbrCalls else 0 end ca,
                              case when country='UK' then nbrCalls else 0 end uk,
                              case when country='US' then nbrCalls else 0 end us,
                              case when country='XX' then nbrCalls else 0 end xx
                       from calls) x
                 group by userid, age""")
res.show()
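For comparison, the same case/when logic can also be sketched with the DataFrame API rather than a SQL string. This is a hedged alternative sketch, assuming Spark 1.4+ (for when/otherwise) and the callsDF defined above; res_df is an illustrative name:

from pyspark.sql import functions as F

# Sketch: one max(case when ...) column per country, built in a loop.
res_df = (callsDF
    .groupBy("userid", "age")
    .agg(*[F.max(F.when(F.col("country") == c, F.col("nbrCalls"))
                  .otherwise(0)).alias(c.lower())
           for c in ['CA', 'UK', 'US', 'XX']]))
res_df.show()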
Data set up:

data = [('X01', 41, 'US', 3), ('X01', 41, 'UK', 1), ('X01', 41, 'CA', 2),
        ('X02', 72, 'US', 4), ('X02', 72, 'UK', 6), ('X02', 72, 'CA', 7),
        ('X02', 72, 'XX', 8)]
calls = sc.parallelize(data, 1)
countries = ['CA', 'UK', 'US', 'XX']
Result:
From the 1st solution:
(('X02', 72), (7, 6, 4, 8))
(('X01', 41), (2, 1, 3, 0))
From the 2nd solution:
root
 |-- age: long (nullable = true)
 |-- country: string (nullable = true)
 |-- nbrCalls: long (nullable = true)
 |-- userid: string (nullable = true)

userid  age  ca  uk  us  xx
X02      72   7   6   4   8
X01      41   2   1   3   0
Kindly let me know if this works, or not :)
Best Ayan