How to explode multiple columns, different types and different lengths?

Tags: python, pyspark

I've got a DF with columns of different time cycles (1/6, 3/6, 6/6 etc.) and would like to "explode" all the columns to create a new DF in which each row is a 1/6 cycle.

from pyspark.sql import Row 
from pyspark.sql import SparkSession 
from pyspark.sql.functions import explode, arrays_zip, col

spark = SparkSession.builder \
    .appName('DataFrame') \
    .master('local[*]') \
    .getOrCreate()

df = spark.createDataFrame([Row(a=1, b=[1, 2, 3, 4, 5, 6], c=[11, 22, 33], d=['foo'])])

+---+------------------+------------+-----+
|  a|                 b|           c|    d|
+---+------------------+------------+-----+
|  1|[1, 2, 3, 4, 5, 6]|[11, 22, 33]|[foo]|
+---+------------------+------------+-----+

I'm doing the explode:

df2 = (df.withColumn("tmp", arrays_zip("b", "c", "d"))
       .withColumn("tmp", explode("tmp"))
       .select("a", col("tmp.b"), col("tmp.c"), "d"))

But the output is not what I want (arrays_zip pads the shorter arrays with null at the end, so the values of c run out after three rows):

+---+---+----+-----+
|  a|  b|   c|    d|
+---+---+----+-----+
|  1|  1|  11|[foo]|
|  1|  2|  22|[foo]|
|  1|  3|  33|[foo]|
|  1|  4|null|[foo]|
|  1|  5|null|[foo]|
|  1|  6|null|[foo]|
+---+---+----+-----+

I would want it to look like this:

+---+---+---+---+
|  a|  b|  c|  d|
+---+---+---+---+
|  1|  1| 11|foo|
|   |  2|   |   |
|   |  3| 22|   |
|   |  4|   |   |
|   |  5| 33|   |
|   |  6|   |   |
+---+---+---+---+

I am new to Spark and have run into complicated topics right from the start! :)

Update 2019-07-15: Does anyone have a solution that avoids UDFs? -> answered by @jxc

Update 2019-07-17: Does anyone have a solution for arranging the null/value sequences in a more complicated order? For example, in column c: Null, 11, Null, 22, Null, 33; or, more complex still, in column d the first value should be Null, then foo, then Null, Null, Null:

+---+---+---+---+
|  a|  b|  c|  d|
+---+---+---+---+
|  1|  1|   |   |
|   |  2| 11|foo|
|   |  3|   |   |
|   |  4| 22|   |
|   |  5|   |   |
|   |  6| 33|   |
+---+---+---+---+
asked Jul 08 '19 by cincin21


1 Answer

Here is one way without using a UDF:

UPDATE on 2019/07/17: adjusted the SQL stmt and added N=6 as a parameter to the SQL.

UPDATE on 2019/07/16: removed the temporary column t and replaced it with a constant array(0,1,2,3,4,5) in the transform function. This way, we can operate on the values of the array elements directly instead of on their indexes.

UPDATE: removed the original method, which used String functions, converted all array elements to String, and was less efficient. The Spark SQL higher-order functions available with Spark 2.4+ are a better choice than the original method.

Setup

from pyspark.sql import functions as F, Row

df = spark.createDataFrame([ Row(a=1, b=[1, 2, 3, 4, 5, 6], c=['11', '22', '33'], d=['foo'], e=[111,222]) ])

>>> df.show()
+---+------------------+------------+-----+----------+
|  a|                 b|           c|    d|         e|
+---+------------------+------------+-----+----------+
|  1|[1, 2, 3, 4, 5, 6]|[11, 22, 33]|[foo]|[111, 222]|
+---+------------------+------------+-----+----------+

# columns you want to array-explode
cols = df.columns

# number of array elements to set
N = 6
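
(A quick aside before the main steps: transform() is one of the Spark SQL higher-order functions introduced in Spark 2.4; it applies a lambda to every element of an array column. A minimal warm-up sketch, assuming the df above; the alias b_doubled is just illustrative:)

# apply a lambda to each element of the array column b (Spark 2.4+)
df.selectExpr("transform(b, x -> x * 2) AS b_doubled").show(truncate=False)
# b_doubled contains [2, 4, 6, 8, 10, 12]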

Using SQL higher-order function: transform

Using the Spark SQL higher-order function transform(), do the following:

  1. Create the following Spark SQL template, where {0} will be replaced by the column name and {1} by N (a rendered example is shown after this list):

    stmt = '''
       CASE
          WHEN '{0}' in ('d') THEN
            transform(sequence(0,{1}-1), x -> IF(x == 1, `{0}`[0], NULL))
          WHEN size(`{0}`) <= {1}/2 AND size(`{0}`) > 1 THEN
            transform(sequence(0,{1}-1), x -> IF(((x+1)*size(`{0}`))%{1} == 0, `{0}`[int((x-1)*size(`{0}`)/{1})], NULL))
          ELSE `{0}`
        END AS `{0}`
    '''
    

    Note: the array transformation is only applied when the array contains more than one element (unless handled in a separate WHEN clause) and at most N/2 elements (in this example, 1 < size <= 3); arrays of any other size are kept as-is.

  2. Run the above SQL with selectExpr() for all required columns

    df1 = df.withColumn('a', F.array('a')) \
            .selectExpr(*[ stmt.format(c,N) for c in cols ])
    
    >>> df1.show()
    +---+------------------+----------------+-----------+---------------+
    |  a|                 b|               c|          d|              e|
    +---+------------------+----------------+-----------+---------------+
    |[1]|[1, 2, 3, 4, 5, 6]|[, 11,, 22,, 33]|[, foo,,,,]|[,, 111,,, 222]|
    +---+------------------+----------------+-----------+---------------+
    
  3. Run arrays_zip and explode:

    df_new = df1.withColumn('vals', F.explode(F.arrays_zip(*cols))) \
                .select('vals.*') \
                .fillna('', subset=cols)
    
    >>> df_new.show()
    +----+---+---+---+----+
    |   a|  b|  c|  d|   e|
    +----+---+---+---+----+
    |   1|  1|   |   |null|
    |null|  2| 11|foo|null|
    |null|  3|   |   | 111|
    |null|  4| 22|   |null|
    |null|  5|   |   |null|
    |null|  6| 33|   | 222|
    +----+---+---+---+----+
    

    Note: fillna('', subset=cols) only changes String-typed columns, which is why the integer column e keeps its nulls.
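
As referenced in step 1, this is what the rendered template looks like for a single column. Printing it is plain string formatting, so no Spark session is needed; c and N = 6 come from the Setup above:

print(stmt.format('c', N))
# CASE
#    WHEN 'c' in ('d') THEN
#      transform(sequence(0,6-1), x -> IF(x == 1, `c`[0], NULL))
#    WHEN size(`c`) <= 6/2 AND size(`c`) > 1 THEN
#      transform(sequence(0,6-1), x -> IF(((x+1)*size(`c`))%6 == 0, `c`[int((x-1)*size(`c`)/6)], NULL))
#    ELSE `c`
#  END AS `c`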

In one method chain:

df_new = df.withColumn('a', F.array('a')) \
           .selectExpr(*[ stmt.format(c,N) for c in cols ]) \
           .withColumn('vals', F.explode(F.arrays_zip(*cols))) \
           .select('vals.*') \
           .fillna('', subset=cols)

Explanation with the transform function:

The transform function (listed below; it reflects an older revision of the requirements):

transform(sequence(0,5), x -> IF((x*size({0}))%6 == 0, {0}[int(x*size({0})/6)], NULL))

As mentioned above, {0} will be replaced with the column name. Here we use column c, which contains 3 elements, as an example:

  • In the transform function, sequence(0,5) creates a constant array array(0,1,2,3,4,5) with 6 elements, and the rest defines a lambda function with one argument x that takes the value of each element.
  • IF(condition, true_value, false_value) is a standard SQL function.
  • The condition applied is (x*size(c))%6 == 0, where size(c)=3. If this condition is true, it returns c[int(x*size(c)/6)]; otherwise it returns NULL. So for x from 0 to 5, we have:

    ((0*3)%6 == 0) true   -->  c[int(0*3/6)] = c[0]
    ((1*3)%6 == 0) false  -->  NULL
    ((2*3)%6 == 0) true   -->  c[int(2*3/6)] = c[1]
    ((3*3)%6 == 0) false  -->  NULL
    ((4*3)%6 == 0) true   -->  c[int(4*3/6)] = c[2]
    ((5*3)%6 == 0) false  -->  NULL
    

The same logic applies to column e, which contains a 2-element array.
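
A quick plain-Python sanity check of the same mapping (no Spark needed); it mirrors the table above for a 3-element column like c:

# reproduce the lambda's logic: keep c[x*len(c)//N] when (x*len(c)) % N == 0
c = ['11', '22', '33']
N = 6
padded = [c[x * len(c) // N] if (x * len(c)) % N == 0 else None for x in range(N)]
print(padded)   # ['11', None, '22', None, '33', None]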

answered Sep 21 '22 by jxc