
get first N elements from dataframe ArrayType column in pyspark

I have a Spark DataFrame whose rows look like this -

1   |   [a, b, c]
2   |   [d, e, f]
3   |   [g, h, i]

Now I want to keep only the first 2 elements from the array column.

1   |   [a, b]
2   |   [d, e]
3   |   [g, h]

How can that be achieved?

Note - I am not extracting a single array element here, but a slice of the array, which may contain multiple elements.

asked Oct 24 '18 by Vipul Sharma

People also ask

How do you get the first value of a column in PySpark?

To do this we will use the first() and head() functions. Syntax: dataframe.first()['column name']
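
For example, a minimal sketch (assuming the example DataFrame df used in the answer below, with an id and a letters column):

df.first()['letters']   # first() returns the first Row; index it by column name
#['a', 'b', 'c']

df.head()['letters']    # head() with no argument behaves the same way
#['a', 'b', 'c']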

How do you get the first n rows of PySpark DataFrame?

In Spark/PySpark you can use the show() action to get the top/first N (5, 10, 100, ...) rows of the DataFrame and display them on the console or in a log. There are also several actions such as take(), tail(), collect(), head(), and first() that return the top or last N rows as a list of Row objects (Array[Row] in Scala).
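
A quick sketch, again assuming the same df as below:

df.show(2)       # prints the first 2 rows to the console
df.take(2)       # returns the first 2 rows as a list of Row objects
df.head(2)       # same as take(2)
df.first()       # returns the first Row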

How do you use ArrayType in PySpark?

Create a DataFrame with an array column, then print its schema to verify that the numbers column is an array whose elements are longs. The same DataFrame can also be created using the explicit StructType syntax.
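
A small sketch of what that might look like (the column names id and numbers are illustrative, and spark is assumed to be an existing SparkSession):

from pyspark.sql.types import StructType, StructField, LongType, ArrayType

data = [(1, [1, 2, 3]), (2, [4, 5, 6])]
schema = StructType([
    StructField("id", LongType(), True),
    StructField("numbers", ArrayType(LongType()), True),  # array-of-long column
])
df_numbers = spark.createDataFrame(data, schema)
df_numbers.printSchema()
#root
# |-- id: long (nullable = true)
# |-- numbers: array (nullable = true)
# |    |-- element: long (containsNull = true)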

What is ArrayType in PySpark?

ArrayType is a widely used collection data type in PySpark. It extends the DataType class, the superclass of all PySpark types, and all elements of an ArrayType column must have the same element type.
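
For example, a minimal sketch:

from pyspark.sql.types import ArrayType, StringType

letters_type = ArrayType(StringType())   # an array column whose elements are all strings
letters_type.simpleString()
#'array<string>'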


1 Answer

Here's how to do it with the API functions.

Suppose your DataFrame were the following:

df.show()
#+---+---------+
#| id|  letters|
#+---+---------+
#|  1|[a, b, c]|
#|  2|[d, e, f]|
#|  3|[g, h, i]|
#+---+---------+

df.printSchema()
#root
# |-- id: long (nullable = true)
# |-- letters: array (nullable = true)
# |    |-- element: string (containsNull = true)

You can use square brackets to access elements in the letters column by index, and wrap that in a call to pyspark.sql.functions.array() to create a new ArrayType column.

import pyspark.sql.functions as f

df.withColumn("first_two", f.array([f.col("letters")[0], f.col("letters")[1]])).show()
#+---+---------+---------+
#| id|  letters|first_two|
#+---+---------+---------+
#|  1|[a, b, c]|   [a, b]|
#|  2|[d, e, f]|   [d, e]|
#|  3|[g, h, i]|   [g, h]|
#+---+---------+---------+

Or, if there are too many indices to list out, you can use a list comprehension:

df.withColumn("first_two", f.array([f.col("letters")[i] for i in range(2)])).show()
#+---+---------+---------+
#| id|  letters|first_two|
#+---+---------+---------+
#|  1|[a, b, c]|   [a, b]|
#|  2|[d, e, f]|   [d, e]|
#|  3|[g, h, i]|   [g, h]|
#+---+---------+---------+

For pyspark versions 2.4+ you can also use pyspark.sql.functions.slice():

df.withColumn("first_two",f.slice("letters",start=1,length=2)).show()
#+---+---------+---------+
#| id|  letters|first_two|
#+---+---------+---------+
#|  1|[a, b, c]|   [a, b]|
#|  2|[d, e, f]|   [d, e]|
#|  3|[g, h, i]|   [g, h]|
#+---+---------+---------+

slice may perform better for large arrays (note that its start index is 1-based, not 0-based).
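
The same operation can also be written as a SQL expression, equivalent to the slice() call above (a sketch, still assuming Spark 2.4+):

df.selectExpr("id", "letters", "slice(letters, 1, 2) AS first_two").show()
#+---+---------+---------+
#| id|  letters|first_two|
#+---+---------+---------+
#|  1|[a, b, c]|   [a, b]|
#|  2|[d, e, f]|   [d, e]|
#|  3|[g, h, i]|   [g, h]|
#+---+---------+---------+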

answered Sep 20 '22 by pault