
What's the most efficient way to accumulate dataframes in pyspark?

I have a dataframe (it could also be any RDD) containing several million rows in a well-known schema like this:

Key | FeatureA | FeatureB
--------------------------
U1  |        0 |         1
U2  |        1 |         1

I need to load a dozen other datasets from disk that contain different features for the same keys. Some datasets are up to a dozen or so columns wide. Imagine:

Key | FeatureC | FeatureD |  FeatureE
-------------------------------------
U1  |        0 |        0 |         1

Key | FeatureF
--------------
U2  |        1

It feels like a fold or an accumulation where I just want to iterate over all the datasets and get back something like this:

Key | FeatureA | FeatureB | FeatureC | FeatureD | FeatureE | FeatureF 
---------------------------------------------------------------------
U1  |        0 |        1 |        0 |        0 |        1 |        0
U2  |        1 |        1 |        0 |        0 |        0 |        1

I've tried loading each dataframe and then joining, but that takes forever once I get past a handful of datasets. Am I missing a common pattern or a more efficient way of accomplishing this task?
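
For reference, this is roughly what I'm doing today; load_feature_df and feature_paths are just placeholders for however the datasets actually get read from disk:

from functools import reduce

# Placeholder loading step: one DataFrame per dataset, each keyed by "Key".
dfs = [load_feature_df(path) for path in feature_paths]

# Chain of outer joins on Key -- gives the right output, but slows to a crawl
# once more than a handful of datasets are involved.
combined = reduce(lambda left, right: left.join(right, on="Key", how="outer"), dfs)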

asked Oct 18 '22 by joshua.ewer

1 Answer

Assuming there is at most one row per key in each DataFrame and all keys are of primitive types, you can try a union followed by an aggregation. Let's start with some imports and example data:

from itertools import chain
from functools import reduce
from pyspark.sql.types import StructType
from pyspark.sql.functions import col, lit, max
from pyspark.sql import DataFrame

# Example data: three DataFrames that share the Key column but carry
# different feature columns.
df1 = sc.parallelize([
    ("U1", 0, 1), ("U2", 1, 1)
]).toDF(["Key", "FeatureA", "FeatureB"])

df2 = sc.parallelize([
    ("U1", 0, 0, 1)
]).toDF(["Key", "FeatureC", "FeatureD", "FeatureE"])

df3 = sc.parallelize([("U2", 1)]).toDF(["Key", "FeatureF"])

dfs = [df1, df2, df3]

Next we can extract a common schema:

output_schema = StructType(
  [df1.schema.fields[0]] + list(chain(*[df.schema.fields[1:] for df in dfs]))
)

and transform all DataFrames to that schema, filling missing columns with nulls:

transformed_dfs = [df.select(*[
  lit(None).cast(c.dataType).alias(c.name) if c.name not in df.columns 
  else col(c.name)
  for c in output_schema.fields
]) for df in dfs]

Finally, a union and a dummy aggregation:

combined = reduce(DataFrame.unionAll, transformed_dfs)
exprs = [max(c).alias(c) for c in combined.columns[1:]]
result = combined.repartition(col("Key")).groupBy(col("Key")).agg(*exprs)
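
One small follow-up, assuming the zeros in the expected output are wanted rather than nulls: features a key never appears with come back as null after the aggregation, so they can be filled afterwards.

# Replace the nulls left by missing features with 0 (numeric columns only).
result = result.na.fill(0)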

If there is more than one row per key but individual columns are still atomic, you can try replacing max with collect_list / collect_set followed by explode.
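
A minimal sketch of that variant, reusing combined from above; note that exploding more than one array column multiplies rows, so which columns you explode depends on the layout you want:

from pyspark.sql.functions import collect_set, explode

# Gather every value per key instead of taking the max.
exprs = [collect_set(c).alias(c) for c in combined.columns[1:]]
grouped = combined.repartition(col("Key")).groupBy(col("Key")).agg(*exprs)

# Then explode whichever array column should be expanded back into rows, e.g.:
exploded = grouped.withColumn("FeatureA", explode(col("FeatureA")))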

answered Oct 21 '22 by zero323