Here's why the way you are doing it doesn't work. First, you are trying to get an integer out of a Row type; the output of your collect looks like this:
>>> mvv_list = mvv_count_df.select('mvv').collect()
>>> mvv_list[0]
Out: Row(mvv=1)
If you take something like this:
>>> firstvalue = mvv_list[0].mvv
>>> firstvalue
Out: 1
you will get the mvv value. If you want all the values of the column as a list, you can do something like this:
>>> mvv_array = [int(row.mvv) for row in mvv_list]
>>> mvv_array
Out: [1,2,3,4]
But if you try the same for the other column, you get:
>>> count_list = mvv_count_df.select('count').collect()
>>> mvv_count = [int(row.count) for row in count_list]
Out: TypeError: int() argument must be a string or a number, not 'builtin_function_or_method'
This happens because count is a built-in method of Row, and the column has the same name as count. A workaround is to rename the count column to _count:
>>> mvv_df = mvv_count_df.selectExpr("mvv as mvv", "count as _count")
>>> mvv_count = [int(row._count) for row in mvv_df.collect()]
But this workaround is not needed, as you can access the column using dictionary syntax:
>>> mvv_array = [int(row['mvv']) for row in mvv_count_df.collect()]
>>> mvv_count = [int(row['count']) for row in mvv_count_df.collect()]
And it will finally work!
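Putting the pieces together, here is a minimal self-contained sketch; the sample data is made up to stand in for your real mvv_count_df:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical sample data standing in for the real mvv_count_df
mvv_count_df = spark.createDataFrame(
    [(1, 5), (2, 9), (3, 3), (4, 1)], ["mvv", "count"]
)

rows = mvv_count_df.collect()  # collect once, reuse for both columns
mvv_array = [int(row['mvv']) for row in rows]
mvv_count = [int(row['count']) for row in rows]  # dict syntax avoids the Row.count method clash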
The following one-liner gives the list you want.
mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()
This will give you all the elements as a list.
mvv_list = list(mvv_count_df.select('mvv').toPandas()['mvv'])
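Note that toPandas() returns a pandas DataFrame, so selecting the column gives a pandas Series; calling its tolist() method is an equivalent way to get the same list:
mvv_list = mvv_count_df.select('mvv').toPandas()['mvv'].tolist()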
I ran a benchmarking analysis and list(mvv_count_df.select('mvv').toPandas()['mvv'])
is the fastest method. I'm very surprised.
I ran the different approaches on 100 thousand / 100 million row datasets using a 5-node i3.xlarge cluster (each node has 30.5 GB of RAM and 4 cores) with Spark 2.4.5. Data was evenly distributed across 20 snappy-compressed Parquet files with a single column.
Here's the benchmarking results (runtimes in seconds):
+-------------------------------------------------------------+---------+-------------+
| Code                                                        | 100,000 | 100,000,000 |
+-------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()   | 0.4     | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])         | 0.4     | 17.5        |
| df.select('col_name').rdd.map(lambda row: row[0]).collect()| 0.9     | 69          |
| [row[0] for row in df.select('col_name').collect()]        | 1.0     | OOM         |
| [r[0] for r in df.select('col_name').toLocalIterator()]    | 1.2     | *           |
+-------------------------------------------------------------+---------+-------------+
* cancelled after 800 seconds
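For reference, here is a rough sketch of how such a comparison can be timed; time_collect is a made-up helper, and df stands in for the single-column DataFrame described above:
import time

def time_collect(label, fn, runs=3):
    # Run fn() a few times and report the best wall-clock time
    best = float("inf")
    for _ in range(runs):
        start = time.time()
        fn()
        best = min(best, time.time() - start)
    print(f"{label}: {best:.1f}s")

time_collect("flatMap", lambda: df.select("col_name").rdd.flatMap(lambda x: x).collect())
time_collect("toPandas", lambda: list(df.select("col_name").toPandas()["col_name"]))
time_collect("map", lambda: df.select("col_name").rdd.map(lambda row: row[0]).collect())
time_collect("comprehension", lambda: [row[0] for row in df.select("col_name").collect()])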
A golden rule to follow when collecting data on the driver node: toPandas was significantly improved in Spark 2.3, so it's probably not the best approach if you're using a Spark version earlier than 2.3.
See here for more details / benchmarking results.
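If you are on Spark 2.3 or later, enabling Apache Arrow usually speeds up toPandas further. A minimal sketch; the config key below is the Spark 2.3/2.4 name, and Spark 3.0 renamed it to spark.sql.execution.arrow.pyspark.enabled:
# Let toPandas use Arrow for the JVM-to-Python transfer
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

mvv_list = list(mvv_count_df.select('mvv').toPandas()['mvv'])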
The following code will help you:
mvv_count_df.select('mvv').rdd.map(lambda row: row[0]).collect()