
Using Python's reduce() to join multiple PySpark DataFrames

Does anyone know why using Python3's functools.reduce() would lead to worse performance when joining multiple PySpark DataFrames than just iteratively joining the same DataFrames using a for loop? Specifically, this gives a massive slowdown followed by an out-of-memory error:

import functools

def join_dataframes(list_of_join_columns, left_df, right_df):
    return left_df.join(right_df, on=list_of_join_columns)

joined_df = functools.reduce(
    functools.partial(join_dataframes, list_of_join_columns), list_of_dataframes,
)

whereas this one doesn't:

joined_df = list_of_dataframes[0]
joined_df.cache()
for right_df in list_of_dataframes[1:]:
    joined_df = joined_df.join(right_df, on=list_of_join_columns)

Any ideas would be greatly appreciated. Thanks!
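To make the argument binding in the reduce version explicit, here is a minimal sketch that uses a hypothetical recording stub in place of the real `join_dataframes` and plain strings standing in for DataFrames (both are assumptions for illustration only; nothing here touches Spark). `functools.partial` fixes `list_of_join_columns` as the first positional argument, so `reduce` ends up calling `join_dataframes(list_of_join_columns, left, right)` with the running result as `left`.

```python
import functools

# Hypothetical stand-in for the question's join_dataframes: instead of calling
# DataFrame.join, it records its arguments so the binding is visible.
calls = []

def join_dataframes(list_of_join_columns, left_df, right_df):
    calls.append((tuple(list_of_join_columns), left_df, right_df))
    return f"({left_df} JOIN {right_df})"

list_of_join_columns = ["id"]

# partial fixes the first positional argument; the result takes (left_df, right_df)
joiner = functools.partial(join_dataframes, list_of_join_columns)

# Strings stand in for DataFrames (an assumption made only for this sketch)
result = functools.reduce(joiner, ["df0", "df1", "df2"])
print(result)  # ((df0 JOIN df1) JOIN df2)
print(calls)
```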

asked Oct 29 '22 05:10 by Eric Smith


2 Answers

One reason is that a reduce or a fold is usually functionally pure: the result of each accumulation operation is not written to the same part of memory, but rather to a new block of memory.

In principle the garbage collector could free the previous block after each accumulation, but if it doesn't you'll allocate memory for each updated version of the accumulator.
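A minimal pure-Python sketch of the distinction drawn here, using lists as a hypothetical accumulator: with a pure step function every accumulation allocates a new object, while a mutating step keeps reusing one. Whether new memory is allocated is decided by the step function passed to `reduce`, not by `reduce` itself; for reference, PySpark's `DataFrame.join` returns a new DataFrame, so both snippets in the question use it in the pure style.

```python
from functools import reduce

def pure_step(acc, x):
    # Functionally pure: returns a brand-new list and leaves acc untouched,
    # so every step allocates a new accumulator object.
    return acc + [x]

def mutating_step(acc, x):
    # Impure: updates acc in place and hands back the same object.
    acc.append(x)
    return acc

pure_start = []
pure_result = reduce(pure_step, range(5), pure_start)

mut_start = []
mut_result = reduce(mutating_step, range(5), mut_start)

print(pure_result, pure_start)  # [0, 1, 2, 3, 4] []
print(mut_result is mut_start)  # True
```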

answered Nov 15 '22 06:11 by Alex


There is no inherent difference, at least as long as you use CPython (different implementations can, but realistically shouldn't, exhibit significantly different behavior in this specific case). If you take a look at the reduce implementation, you'll see it is just a for loop with minimal exception handling.

The core is exactly equivalent to the loop you use:

for element in it:
    value = function(value, element)

and there is no evidence supporting claims of any special behavior.
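The loop at the heart of `reduce` can be written out in pure Python. This is a rough sketch that mirrors the documented behavior of `functools.reduce` (optional initial value, `TypeError` on an empty iterable without one); `my_reduce` is a hypothetical name, and this is not CPython's actual C implementation.

```python
import functools
import operator

_MISSING = object()  # sentinel meaning "no initial value supplied"

def my_reduce(function, iterable, initial=_MISSING):
    """Rough pure-Python equivalent of functools.reduce (a sketch)."""
    it = iter(iterable)
    if initial is _MISSING:
        try:
            value = next(it)
        except StopIteration:
            raise TypeError("reduce() of empty iterable with no initial value") from None
    else:
        value = initial
    # The core: the same for loop the question uses
    for element in it:
        value = function(value, element)
    return value

print(my_reduce(operator.add, [1, 2, 3, 4]))          # 10
print(functools.reduce(operator.add, [1, 2, 3, 4]))   # 10
```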

Additionally, simple tests with a number of frames within the practical limits of Spark joins (joins are among the most expensive operations in Spark)

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

dfs = [
    spark.range(10000).selectExpr(
        "rand({}) AS id".format(i), "id AS value", "{} AS loop".format(i)
    )
    for i in range(200)
]

show no significant difference in timing between a direct for loop

def f(dfs):
    df1 = dfs[0]
    for df2 in dfs[1:]:
        df1 = df1.join(df2, ["id"])
    return df1

%timeit -n3 f(dfs)                 
## 6.25 s ± 257 ms per loop (mean ± std. dev. of 7 runs, 3 loops each)

and a reduce invocation

from functools import reduce

def g(dfs):
    return reduce(lambda x, y: x.join(y, ["id"]), dfs) 

%timeit -n3 g(dfs)
## 6.47 s ± 455 ms per loop (mean ± std. dev. of 7 runs, 3 loops each)
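`%timeit` is an IPython magic; outside IPython a similar measurement can be sketched with the standard `timeit` and `statistics` modules. `timeit_like` is a hypothetical helper: it runs the callable `number` times per run for `repeat` runs and reports the per-loop mean and population standard deviation, which approximates, but is not guaranteed to match, IPython's output.

```python
import statistics
import timeit

def timeit_like(func, *args, number=3, repeat=7):
    """Hypothetical helper approximating `%timeit -n3` (7 runs of 3 loops each).

    Returns (mean, stdev) of the per-loop time in seconds.
    """
    # Each entry is the total time for `number` calls within one run.
    totals = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    per_loop = [t / number for t in totals]
    # pstdev: population standard deviation over the runs
    return statistics.mean(per_loop), statistics.pstdev(per_loop)

if __name__ == "__main__":
    # Stand-in workload; with the answer's f, g and dfs in scope you would
    # call timeit_like(f, dfs) and timeit_like(g, dfs) instead.
    mean, stdev = timeit_like(sum, range(1000))
    print(f"{mean * 1e6:.1f} µs ± {stdev * 1e6:.1f} µs per loop")
```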

Similarly, overall JVM behavior patterns are comparable between the for loop

[Figure: for-loop CPU and memory usage, VisualVM]

and reduce

[Figure: reduce CPU and memory usage, VisualVM]

Finally, both generate identical optimized plans

g(dfs)._jdf.queryExecution().optimizedPlan().equals( 
    f(dfs)._jdf.queryExecution().optimizedPlan()
)
## True

which indicates there is no difference at the stage where the plan is evaluated and OOMs are likely to occur.
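The plan equality can be illustrated without a Spark session. In this sketch `Scan`, `Join`, `join`, `plan_with_loop` and `plan_with_reduce` are hypothetical stand-ins, not Spark's Catalyst API: both strategies fold the list into the same left-deep tree, join(join(join(df0, df1), df2), df3), mirroring the `optimizedPlan().equals(...)` check in the answer.

```python
from dataclasses import dataclass
from functools import reduce
from typing import Tuple, Union

# Hypothetical, heavily simplified plan nodes (NOT Spark's Catalyst classes).
@dataclass(frozen=True)
class Scan:
    name: str

@dataclass(frozen=True)
class Join:
    left: "Node"
    right: "Node"
    on: Tuple[str, ...]

Node = Union[Scan, Join]

def join(left: Node, right: Node, on: Tuple[str, ...] = ("id",)) -> Node:
    # Like DataFrame.join, build a new node instead of mutating the inputs.
    return Join(left, right, on)

def plan_with_loop(tables):
    acc = tables[0]
    for table in tables[1:]:
        acc = join(acc, table)
    return acc

def plan_with_reduce(tables):
    return reduce(join, tables)

tables = [Scan(f"df{i}") for i in range(4)]
# Frozen dataclasses compare field by field, so structurally equal trees are equal.
print(plan_with_loop(tables) == plan_with_reduce(tables))  # True
```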

In other words, the correlation you observed doesn't imply causation, and the performance problems are unlikely to be related to the method you use to combine DataFrames.

answered Nov 15 '22 05:11 by user11024414