 

PySpark converting a column of type 'map' to multiple columns in a dataframe

Input

I have a column Parameters of type map of the form:

>>> from pyspark.sql import SQLContext
>>> sqlContext = SQLContext(sc)
>>> d = [{'Parameters': {'foo': '1', 'bar': '2', 'baz': 'aaa'}}]
>>> df = sqlContext.createDataFrame(d)
>>> df.collect()
[Row(Parameters={'foo': '1', 'bar': '2', 'baz': 'aaa'})]

Output

I want to reshape it in pyspark so that all the keys (foo, bar, etc.) are columns, namely:

[Row(foo='1', bar='2', baz='aaa')]

Using withColumn works:

(df
 .withColumn('foo', df.Parameters['foo'])
 .withColumn('bar', df.Parameters['bar'])
 .withColumn('baz', df.Parameters['baz'])
 .drop('Parameters')
).collect()

But I need a solution that doesn't explicitly mention the column names, as I have dozens of them.

Schema

>>> df.printSchema()

root
 |-- Parameters: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)
asked Apr 26 '16 by Kamil Sindi


People also ask

How do I change the data type of multiple columns in a PySpark DataFrame?

Method 1: Using DataFrame.withColumn(). Use the cast(dataType) method to cast a column to a different data type: select the column by name and call .cast() with the type you want to convert it to.
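
A minimal sketch of this, assuming a hypothetical DataFrame with two string columns (names and data are illustrative, not from the question):

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()
df2 = spark.createDataFrame([('1', '2')], ['foo', 'bar'])  # hypothetical data
# Cast every string column to integer in a single select:
df2 = df2.select(*[col(c).cast('integer') for c in df2.columns])
df2.printSchema()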

How do I map a column in PySpark?

Solution: the PySpark SQL function create_map() is used to convert selected DataFrame columns to MapType. create_map() takes the columns you want to convert, as alternating key and value expressions, and returns a MapType column.
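
As a rough sketch, again with hypothetical column names not tied to the question's data:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, create_map, lit

spark = SparkSession.builder.getOrCreate()
df2 = spark.createDataFrame([('1', '2')], ['foo', 'bar'])  # hypothetical data
# create_map() takes alternating key and value expressions:
df2 = df2.withColumn('Parameters', create_map(lit('foo'), col('foo'), lit('bar'), col('bar')))
df2.show(truncate=False)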

How do I change the data type of a column in PySpark?

withColumn() – Change Column Type: use withColumn() to convert the data type of a DataFrame column. This function takes the column name you want to convert as the first argument; for the second argument, apply the cast() method with the target DataType on the column.

How do you split a column in PySpark DataFrame?

PySpark SQL provides the split() function to convert a delimiter-separated string to an array (StringType to ArrayType) column on a DataFrame. It splits the string column on a delimiter such as a space, comma, or pipe, converting it into an ArrayType column.
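
A minimal sketch, again with hypothetical data:

from pyspark.sql import SparkSession
from pyspark.sql.functions import split

spark = SparkSession.builder.getOrCreate()
df2 = spark.createDataFrame([('a,b,c',)], ['letters'])  # hypothetical data
# Split the comma-separated string into an ArrayType column:
df2 = df2.withColumn('letters', split(df2.letters, ','))
df2.show()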


2 Answers

Since the keys of a MapType column are not part of the schema, you'll have to collect them first, for example like this:

from pyspark.sql.functions import explode

keys = (df
    .select(explode("Parameters"))  # one row per (key, value) pair
    .select("key")
    .distinct()
    .rdd.flatMap(lambda x: x)
    .collect())  # pull the distinct keys back to the driver
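
With the sample data, the collected keys look something like:

keys
# e.g. ['bar', 'foo', 'baz'] (order is not deterministic)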

Once you have this, all that is left is a simple select:

from pyspark.sql.functions import col

exprs = [col("Parameters").getItem(k).alias(k) for k in keys]
df.select(*exprs)
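
On the sample data this produces the reshaped rows the question asks for (column order follows the order in which the keys were collected):

df.select(*exprs).collect()
# e.g. [Row(foo='1', bar='2', baz='aaa')]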
answered by zero323


Performant solution

One of the question's constraints is to determine the column names dynamically, which is fine, but be warned that dynamically fetching the map keys requires an extra Spark job and can be really slow. Here's how you can avoid that and write code that'll execute quickly when you already know the keys.

import pyspark.sql.functions as F

cols = list(map(
    lambda f: F.col("Parameters").getItem(f).alias(str(f)),
    ["foo", "bar", "baz"]))
df.select(cols).show()
+---+---+---+
|foo|bar|baz|
+---+---+---+
|  1|  2|aaa|
+---+---+---+

Notice that this runs a single select operation. Don't call withColumn multiple times, because each call adds another projection to the query plan and slows things down.

The fast solution is only possible if you know all the map keys. You'll need to revert to the slower solution if you don't know all the distinct keys in advance.

Slower solution

The accepted answer is good. My solution is a bit more performant because it doesn't call .rdd or flatMap().

import pyspark.sql.functions as F

d = [{'Parameters': {'foo': '1', 'bar': '2', 'baz': 'aaa'}}]
df = spark.createDataFrame(d)

keys_df = df.select(F.explode(F.map_keys(F.col("Parameters")))).distinct()
keys = list(map(lambda row: row[0], keys_df.collect()))
key_cols = list(map(lambda f: F.col("Parameters").getItem(f).alias(str(f)), keys))
df.select(key_cols).show()
+---+---+---+
|bar|foo|baz|
+---+---+---+
|  2|  1|aaa|
+---+---+---+

Collecting results to the driver node can be a performance bottleneck. It's good to execute the list(map(lambda row: row[0], keys_df.collect())) step as a separate command, so you can make sure it isn't running too slowly.

answered by Powers