pyspark matrix with dummy variables

I have a DataFrame with two columns:

ID  Text
1    a
2    b
3    c

How can I create a matrix of dummy variables like this:

ID a b c
1  1 0 0
2  0 1 0
3  0 0 1

using PySpark and its features?

Keithx asked Mar 08 '16 22:03
1 Answer

An alternative solution is to use Spark's pivot method, available since Spark 1.6.0.

Example:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([
    (1, "a"),
    (2, "b"),
    (3, "c")],
    ["ID", "Text"])

pivoted = df.groupBy("ID").pivot("Text").agg(F.lit(1))  # one column per distinct Text value
pivoted.show()
# +---+----+----+----+
# | ID|   a|   b|   c|
# +---+----+----+----+
# |  1|   1|null|null|
# |  3|null|null|   1|
# |  2|null|   1|null|
# +---+----+----+----+

To get rid of the null values, use the na methods:

pivoted.na.fill(0).show()
# +---+---+---+---+
# | ID|  a|  b|  c|
# +---+---+---+---+
# |  1|  1|  0|  0|
# |  3|  0|  0|  1|
# |  2|  0|  1|  0|
# +---+---+---+---+

Pivoting is more general than the mapping solution proposed by ksindi, as it can aggregate arbitrary numbers rather than only flag presence. That said, ksindi's solution is more efficient in this particular case: it requires only one pass over the data, or two if you count the pass that collects the categories. With pivot you can also pass the categories as the second positional argument, which spares Spark the extra job it otherwise runs to discover the distinct values (see the sketch below). Even then, the groupBy call causes a shuffle, which makes this approach the slower one.
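For example, supplying the known categories up front looks like this; the result is identical to the output above:

pivoted = df.groupBy("ID").pivot("Text", ["a", "b", "c"]).agg(F.lit(1))
pivoted.na.fill(0).show()  # same output as above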

Note: the groupBy call silently assumes that the ID column contains unique values; that assumption is what yields the desired output. Had the example DataFrame looked like this instead:

df = spark.createDataFrame([
    (1, "a"),
    (2, "b"),
    (3, "c"),
    (3, "a")],
    ["ID", "Text"])

the outcome of the pivot solution would have been

df.groupBy("ID").pivot("Text").agg(F.lit(1)).na.fill(0).show()
# +---+---+---+---+
# | ID|  a|  b|  c|
# +---+---+---+---+
# |  1|  1|  0|  0|
# |  3|  1|  0|  1|
# |  2|  0|  1|  0|
# +---+---+---+---+

Whereas the mapping solution, using the exprs list from ksindi's answer (a sketch of which follows the output below), would end up as

df.select("ID", *exprs).show()
# +---+---+---+---+
# | ID|  c|  b|  a|
# +---+---+---+---+
# |  1|  0|  0|  1|
# |  2|  0|  1|  0|
# |  3|  1|  0|  0|
# |  3|  0|  0|  1|
# +---+---+---+---+
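The exprs list is not defined in this post; a minimal sketch of how it could be built, assuming the three categories are already known (ksindi's answer derives them from the data in an extra pass), is:

categories = ["a", "b", "c"]  # assumed known here; otherwise collect the distinct Text values first
exprs = [F.when(F.col("Text") == c, 1).otherwise(0).alias(c) for c in categories]
df.select("ID", *exprs).show()

With this sketch the column order follows the list ("a", "b", "c") rather than the "c", "b", "a" order shown above, which depends on how the categories were collected.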
Oliver W. answered Oct 13 '22 21:10