Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Convert spark DataFrame column to python list

See, why this way that you are doing is not working. First, you are trying to get integer from a Row Type, the output of your collect is like this:

>>> mvv_list = mvv_count_df.select('mvv').collect()
>>> mvv_list[0]
Out: Row(mvv=1)

If you take something like this:

>>> firstvalue = mvv_list[0].mvv
Out: 1

You will get the mvv value. If you want all the information of the array you can take something like this:

>>> mvv_array = [int(row.mvv) for row in mvv_list.collect()]
>>> mvv_array
Out: [1,2,3,4]

But if you try the same for the other column, you get:

>>> mvv_count = [int(row.count) for row in mvv_list.collect()]
Out: TypeError: int() argument must be a string or a number, not 'builtin_function_or_method'

This happens because count is a built-in method. And the column has the same name as count. A workaround to do this is change the column name of count to _count:

>>> mvv_list = mvv_list.selectExpr("mvv as mvv", "count as _count")
>>> mvv_count = [int(row._count) for row in mvv_list.collect()]

But this workaround is not needed, as you can access the column using the dictionary syntax:

>>> mvv_array = [int(row['mvv']) for row in mvv_list.collect()]
>>> mvv_count = [int(row['count']) for row in mvv_list.collect()]

And it will finally work!


Following one liner gives the list you want.

mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()

This will give you all the elements as a list.

mvv_list = list(
    mvv_count_df.select('mvv').toPandas()['mvv']
)

I ran a benchmarking analysis and list(mvv_count_df.select('mvv').toPandas()['mvv']) is the fastest method. I'm very surprised.

I ran the different approaches on 100 thousand / 100 million row datasets using a 5 node i3.xlarge cluster (each node has 30.5 GBs of RAM and 4 cores) with Spark 2.4.5. Data was evenly distributed on 20 snappy compressed Parquet files with a single column.

Here's the benchmarking results (runtimes in seconds):

+-------------------------------------------------------------+---------+-------------+
|                          Code                               | 100,000 | 100,000,000 |
+-------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()    |     0.4 | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])          |     0.4 | 17.5        |
| df.select('col_name').rdd.map(lambda row : row[0]).collect()|     0.9 | 69          |
| [row[0] for row in df.select('col_name').collect()]         |     1.0 | OOM         |
| [r[0] for r in mid_df.select('col_name').toLocalIterator()] |     1.2 | *           |
+-------------------------------------------------------------+---------+-------------+

* cancelled after 800 seconds

Golden rules to follow when collecting data on the driver node:

  • Try to solve the problem with other approaches. Collecting data to the driver node is expensive, doesn't harness the power of the Spark cluster, and should be avoided whenever possible.
  • Collect as few rows as possible. Aggregate, deduplicate, filter, and prune columns before collecting the data. Send as little data to the driver node as you can.

toPandas was significantly improved in Spark 2.3. It's probably not the best approach if you're using a Spark version earlier than 2.3.

See here for more details / benchmarking results.


The following code will help you

mvv_count_df.select('mvv').rdd.map(lambda row : row[0]).collect()