
Evaluating Spark DataFrame in loop slows down with every iteration, all work done by controller

I am trying to use a Spark cluster (running on AWS EMR) to link groups of items that have common elements in them. Essentially, I have groups with some elements and if some of the elements are in multiple groups, I want to make one group that contains elements from all of those groups.

I know about the GraphX library and I tried to use the graphframes package (its ConnectedComponents algorithm) to solve this task, but it seems that the graphframes package is not yet mature enough and is very wasteful with resources... Running it on my data set (circa 60 GB), it just runs out of memory no matter how much I tune the Spark parameters, how I partition and re-partition my data, or how big a cluster I create (the graph IS huge).
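Roughly, the graphframes attempt treated items and group names as vertices of a bipartite graph and linked every item to the groups it appears in, something like the sketch below, using the item_links table I describe further down (the exact construction here is just an illustration, assuming Spark 2.x with a SparkSession called spark and no name collisions between items and groups):

from pyspark.sql.functions import col
from graphframes import GraphFrame

# vertices: all items plus all group names; edges: item -> group membership
vertices = item_links.select(col('item').alias('id')) \
                     .union(item_links.select(col('group_name').alias('id'))) \
                     .distinct()
edges = item_links.select(col('item').alias('src'),
                          col('group_name').alias('dst'))

# connectedComponents() needs a checkpoint directory (path is illustrative)
spark.sparkContext.setCheckpointDir('/tmp/graphframes-checkpoints')
components = GraphFrame(vertices, edges).connectedComponents()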

So I wrote my own code to accomplish the task. The code works and it solves my problem, but it slows down with every iteration. Since it can sometimes take around 10 iterations to finish, it can run very long and I could not figure out what the problem is.

I start with a table (DataFrame) item_links that has two columns: item and group_name. Items are unique within each group, but not within this table. One item can be in multiple groups. If two items each have a row with the same group name, they both belong to the same group.

I first group by item and find for every item the smallest of all group names from all the groups it belongs to. I append this information as an extra column to the original DataFrame. Then I create a new DataFrame by grouping by the group name and finding the smallest value of this new column within every group. I join this DataFrame with my original table on the group name and replace the group name column with the minimum value from that new column. The idea is that if a group contains an item that also belongs to some smaller group, this group will be merged with it. Each iteration links groups that were previously only indirectly linked, through longer and longer chains of shared items.
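To make this concrete, here is a tiny made-up example of item_links and the result I expect once the merging converges (the real data is of course much larger; this assumes a SparkSession called spark):

# toy illustration only
item_links = spark.createDataFrame(
    [('A', 'g1'), ('B', 'g1'),   # A and B share group g1
     ('B', 'g2'), ('C', 'g2'),   # B is also in g2, so g1 and g2 must end up merged
     ('D', 'g3')],               # D shares nothing, so g3 stays on its own
    ['item', 'group_name'])

# expected final assignment: A, B, C -> g1 (g2 folds into the alphabetically smaller g1), D -> g3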

The code that I am running looks like this:

print(" Merging groups that have common items...")

n_partitions = 32

merge_level = 0

min_new_group = "min_new_group_{}".format(merge_level)

# For every item identify the (alphabetically) first group in which this item was found
# and add a new column min_new_group with that information for every item.
first_group = item_links \
                    .groupBy('item') \
                    .agg( min('group_name').alias(min_new_group) ) \
                    .withColumnRenamed('item', 'item_id') \
                    .coalesce(n_partitions) \
                    .cache()

item_links = item_links \
                .join( first_group,
                       item_links['item'] == first_group['item_id'] ) \
                .drop(first_group['item_id']) \
                .coalesce(n_partitions) \
                .cache()

first_group.unpersist()

# In every group find the (alphabetically) smallest min_new_group value.
# If the group contains an item that was in some other group,
# this value will be different than the current group_name.
merged_groups = item_links \
                    .groupBy('group_name') \
                    .agg(
                        min(col(min_new_group)).alias('merged_group')
                    ) \
                    .withColumnRenamed('group_name', 'group_to_merge') \
                    .coalesce(n_partitions) \
                    .cache()

# Replace the group_name column with the lowest group that any of the items in the group had.
item_links = item_links \
                .join( merged_groups,
                       item_links['group_name'] == merged_groups['group_to_merge'] ) \
                .drop(item_links['group_name']) \
                .drop(merged_groups['group_to_merge']) \
                .drop(item_links[min_new_group]) \
                .withColumnRenamed('merged_group', 'group_name') \
                .coalesce(n_partitions) \
                .cache()

# Count the number of common items found
common_items_count = merged_groups.filter(col('merged_group') != col('group_to_merge')).count()

merged_groups.unpersist()

# just some debug output
print("  level {}: found {} common items".format(merge_level, common_items_count))

# As long as the number of groups keeps decreasing (groups are being merged together), repeat the operation.
while (common_items_count > 0):
    merge_level += 1

    min_new_group = "min_new_group_{}".format(merge_level)

    # for every item find new minimal group...
    first_group = item_links \
                        .groupBy('item') \
                        .agg(
                            min(col('group_name')).alias(min_new_group)
                        ) \
                        .withColumnRenamed('item', 'item_id') \
                        .coalesce(n_partitions) \
                        .cache() 

    item_links = item_links \
                    .join( first_group,
                           item_links['item'] == first_group['item_id'] ) \
                    .drop(first_group['item_id']) \
                    .coalesce(n_partitions) \
                    .cache()

    first_group.unpersist()

    # find groups that have items from other groups...
    merged_groups = item_links \
                        .groupBy(col('group_name')) \
                        .agg(
                            min(col(min_new_group)).alias('merged_group')
                        ) \
                        .withColumnRenamed('group_name', 'group_to_merge') \
                        .coalesce(n_partitions) \
                        .cache()

    # merge the groups with items from other groups...
    item_links = item_links \
                    .join( merged_groups,
                           item_links['group_name'] == merged_groups['group_to_merge'] ) \
                    .drop(item_links['group_name']) \
                    .drop(merged_groups['group_to_merge']) \
                    .drop(item_links[min_new_group]) \
                    .withColumnRenamed('merged_group', 'group_name') \
                    .coalesce(n_partitions) \
                    .cache()

    common_items_count = merged_groups.filter(col('merged_group') != col('group_to_merge')).count()

    merged_groups.unpersist()

    print("  level {}: found {} common items".format(merge_level, common_items_count))

As I said, it works, but the problem is that it slows down with every iteration. Iterations 1-3 run in just a few seconds or minutes. Iteration 5 runs around 20-40 minutes. Iteration 6 sometimes doesn't even finish, because the controller runs out of memory (14 GB for the controller, around 140 GB of RAM for the entire cluster with 20 CPU cores... the test data is around 30 GB).

When I monitor the cluster in Ganglia, I see that after every iteration the workers perform less and less work and the controller performs more and more. The network traffic also drops to zero. Memory usage is rather stable after the initial phase.

I have read a lot about re-partitioning, tuning Spark parameters and the background of shuffle operations, and I did my best to optimize everything, but I have no idea what's going on here. Below is the load of my cluster nodes (yellow for the controller node) over time while the code above is running.

load of cluster nodes, yellow is controller

asked Aug 22 '16 by grepe




2 Answers

Try calling dataFrame.explain() to see the logical plan. With every iteration, the transformations on this DataFrame keep adding to the logical plan, so the evaluation time keeps growing.
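For example, with the item_links DataFrame from the question (passing True prints the extended plans, not just the physical one):

# the printed plan keeps growing with every iteration of the loop
item_links.explain(True)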

You can use the solution below as a workaround:

dataFrame.rdd.localCheckpoint()

This writes the RDD behind this DataFrame to memory, removes the lineage, and then recreates the RDD from the data written to memory.

The good thing about this is that you don't need to write your RDD to HDFS or disk. However, it also brings some issues with it, which may or may not affect you. You can read the documentation of the localCheckpoint() method for details.
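A sketch of how this can be wired into the loop above, using the names from the question (on Spark 2.3+ the DataFrame itself offers localCheckpoint(), which is the more convenient variant; on older versions you have to go through the .rdd route shown above):

# truncate the accumulated lineage at the end of every iteration (Spark 2.3+)
item_links = item_links.localCheckpoint()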

answered Oct 08 '22 by Amanpreet Khurana


I resolved this issue by saving the DataFrame to HDFS at the end of every iteration and reading it back from HDFS at the beginning of the next one.
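Roughly, the extra step at the end of every iteration looks like this (the path and the Parquet format are just what I use as an illustration; assumes a SparkSession called spark):

# write the current state out and read it back, which truncates the query plan / lineage
checkpoint_path = 'hdfs:///tmp/item_links_iter_{}'.format(merge_level)
item_links.write.mode('overwrite').parquet(checkpoint_path)
item_links = spark.read.parquet(checkpoint_path)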

Since I do that, the program runs like a breeze and doesn't show any signs of slowing down, excessive memory use, or driver overload.

I still don't understand why this happens, so I'm leaving the question open.

answered Oct 08 '22 by grepe