 

Pyspark: Split multiple array columns into rows

I have a dataframe which has one row, and several columns. Some of the columns are single values, and others are lists. All list columns are the same length. I want to split each list column into a separate row, while keeping any non-list column as is.

Sample DF:

from pyspark import Row
from pyspark.sql import SQLContext
from pyspark.sql.functions import explode

sqlc = SQLContext(sc)

df = sqlc.createDataFrame([Row(a=1, b=[1,2,3], c=[7,8,9], d='foo')])
# +---+---------+---------+---+
# |  a|        b|        c|  d|
# +---+---------+---------+---+
# |  1|[1, 2, 3]|[7, 8, 9]|foo|
# +---+---------+---------+---+

What I want:

+---+---+----+------+
|  a|  b|  c |    d |
+---+---+----+------+
|  1|  1|  7 |  foo |
|  1|  2|  8 |  foo |
|  1|  3|  9 |  foo |
+---+---+----+------+

If I only had one list column, this would be easy; I could just explode it:

df_exploded = df.withColumn('b', explode('b'))
# >>> df_exploded.show()
# +---+---+---------+---+
# |  a|  b|        c|  d|
# +---+---+---------+---+
# |  1|  1|[7, 8, 9]|foo|
# |  1|  2|[7, 8, 9]|foo|
# |  1|  3|[7, 8, 9]|foo|
# +---+---+---------+---+

However, if I also explode the c column, I end up with a dataframe whose length is the square of what I want:

df_exploded_again = df_exploded.withColumn('c', explode('c'))
# >>> df_exploded_again.show()
# +---+---+---+---+
# |  a|  b|  c|  d|
# +---+---+---+---+
# |  1|  1|  7|foo|
# |  1|  1|  8|foo|
# |  1|  1|  9|foo|
# |  1|  2|  7|foo|
# |  1|  2|  8|foo|
# |  1|  2|  9|foo|
# |  1|  3|  7|foo|
# |  1|  3|  8|foo|
# |  1|  3|  9|foo|
# +---+---+---+---+

What I want is, for each column, to take the nth element of the array in that column and add it to a new row. I've tried mapping an explode across all the list columns in the dataframe, but that doesn't seem to work either:

df_split = df.rdd.map(lambda col: df.withColumn(col, explode(col))).toDF() 
asked Dec 07 '16 by Steve



2 Answers

Spark >= 2.4

You can replace the zip_ UDF (from the Spark < 2.4 solution below) with the arrays_zip function:

from pyspark.sql.functions import arrays_zip, col, explode

(df
    .withColumn("tmp", arrays_zip("b", "c"))
    .withColumn("tmp", explode("tmp"))
    .select("a", col("tmp.b"), col("tmp.c"), "d"))
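
On the sample DataFrame from the question, this should produce the three desired rows; a quick sanity-check sketch (the result variable is just for illustration):

result = (df
    .withColumn("tmp", arrays_zip("b", "c"))
    .withColumn("tmp", explode("tmp"))
    .select("a", col("tmp.b"), col("tmp.c"), "d"))

result.show()
# +---+---+---+---+
# |  a|  b|  c|  d|
# +---+---+---+---+
# |  1|  1|  7|foo|
# |  1|  2|  8|foo|
# |  1|  3|  9|foo|
# +---+---+---+---+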

Spark < 2.4

With DataFrames and a UDF:

from pyspark.sql.types import ArrayType, StructType, StructField, IntegerType
from pyspark.sql.functions import col, udf, explode

zip_ = udf(
    lambda x, y: list(zip(x, y)),
    ArrayType(StructType([
        # Adjust types to reflect data types
        StructField("first", IntegerType()),
        StructField("second", IntegerType())
    ]))
)

(df
    .withColumn("tmp", zip_("b", "c"))
    # UDF output cannot be directly passed to explode
    .withColumn("tmp", explode("tmp"))
    .select("a", col("tmp.first").alias("b"), col("tmp.second").alias("c"), "d"))
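
If the two arrays hold different element types, only the declared return schema needs to change; a minimal sketch, assuming a hypothetical string-typed array column zipped with b:

from pyspark.sql.types import ArrayType, StructType, StructField, IntegerType, StringType
from pyspark.sql.functions import udf

# Hypothetical variant: zip array<int> with array<string>;
# the lambda is unchanged, only the declared return type differs
zip_mixed = udf(
    lambda x, y: list(zip(x, y)),
    ArrayType(StructType([
        StructField("first", IntegerType()),
        StructField("second", StringType())
    ]))
)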

With RDDs:

(df
    .rdd
    .flatMap(lambda row: [(row.a, b, c, row.d) for b, c in zip(row.b, row.c)])
    .toDF(["a", "b", "c", "d"]))
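
The same idea extends to any number of array columns; a rough sketch (list_cols and other_cols are illustrative names, and the output column order follows other_cols + list_cols):

# Hypothetical generalization: zip all array columns row-wise
# and carry the scalar columns through unchanged
list_cols = ["b", "c"]
other_cols = ["a", "d"]

(df
    .rdd
    .flatMap(lambda row: [
        tuple(row[c] for c in other_cols) + vals
        for vals in zip(*[row[c] for c in list_cols])
    ])
    .toDF(other_cols + list_cols))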

Both solutions are inefficient due to Python communication overhead. If the array length is fixed, you can do something like this:

from functools import reduce
from pyspark.sql import DataFrame

# Length of array
n = 3

# For legacy Python you'll need a separate function
# in place of the method accessor

reduce(
    DataFrame.unionAll,
    (df.select("a", col("b").getItem(i), col("c").getItem(i), "d")
        for i in range(n))
).toDF("a", "b", "c", "d")
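
If you'd rather not hard-code n, one option is to read the array length from the data first (a sketch that assumes, as the question states, all arrays have the same length; first() triggers a small job):

from pyspark.sql.functions import size

# Look up the array length from the first row
n = df.select(size("b").alias("n")).first()["n"]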

or even:

from pyspark.sql.functions import array, struct

# SQL-level zip of arrays of known size
# followed by explode
tmp = explode(array(*[
    struct(col("b").getItem(i).alias("b"), col("c").getItem(i).alias("c"))
    for i in range(n)
]))

(df
    .withColumn("tmp", tmp)
    .select("a", col("tmp").getItem("b"), col("tmp").getItem("c"), "d"))

This should be significantly faster than the UDF or RDD solutions. Generalized to support an arbitrary number of columns:

# This uses keyword-only arguments
# If you use legacy Python you'll have to change the signature
# The body of the function can stay the same
def zip_and_explode(*colnames, n):
    return explode(array(*[
        struct(*[col(c).getItem(i).alias(c) for c in colnames])
        for i in range(n)
    ]))

df.withColumn("tmp", zip_and_explode("b", "c", n=3))
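
For completeness, a usage sketch on the sample data; the final select unpacks the struct back into top-level columns:

(df
    .withColumn("tmp", zip_and_explode("b", "c", n=3))
    .select("a", col("tmp.b"), col("tmp.c"), "d"))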
answered by zero323

You'd need to use flatMap, not map, since you want to make multiple output rows out of each input row.

from pyspark.sql import Row

def dualExplode(r):
    rowDict = r.asDict()
    bList = rowDict.pop('b')
    cList = rowDict.pop('c')
    for b, c in zip(bList, cList):
        newDict = dict(rowDict)
        newDict['b'] = b
        newDict['c'] = c
        yield Row(**newDict)

df_split = sqlContext.createDataFrame(df.rdd.flatMap(dualExplode))
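
Applied to the sample DataFrame, this should give the same three rows as the other answers; a quick sanity-check sketch (field order inside the Row may vary by Spark version):

df_split.orderBy("b").show()
# +---+---+---+---+
# |  a|  b|  c|  d|
# +---+---+---+---+
# |  1|  1|  7|foo|
# |  1|  2|  8|foo|
# |  1|  3|  9|foo|
# +---+---+---+---+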
answered by David