I have a dataframe (or it could be any RDD) containing several million rows in a well-known schema like this:
Key | FeatureA | FeatureB
--------------------------
U1 | 0 | 1
U2 | 1 | 1
I need to load a dozen other datasets from disk that contain different features for the same set of keys. Some datasets are up to a dozen or so columns wide. Imagine:
Key | FeatureC | FeatureD | FeatureE
-------------------------------------
U1 | 0 | 0 | 1
Key | FeatureF
--------------
U2 | 1
It feels like a fold or an accumulation, where I just want to iterate over all the datasets and get back something like this:
Key | FeatureA | FeatureB | FeatureC | FeatureD | FeatureE | FeatureF
---------------------------------------------------------------------
U1 | 0 | 1 | 0 | 0 | 1 | 0
U2 | 1 | 1 | 0 | 0 | 0 | 1
I've tried loading each dataframe and then joining, but that takes forever once I get past a handful of datasets. Am I missing a common pattern or an efficient way of accomplishing this task?
Assuming there is at most one row per key in each DataFrame, and all keys are of primitive types, you can try a union with an aggregation. Let's start with some imports and example data:
from itertools import chain
from functools import reduce
from pyspark.sql.types import StructType
from pyspark.sql.functions import col, lit, max
from pyspark.sql import DataFrame
df1 = sc.parallelize([
("U1", 0, 1), ("U2", 1, 1)
]).toDF(["Key", "FeatureA", "FeatureB"])
df2 = sc.parallelize([
("U1", 0, 0, 1)
]).toDF(["Key", "FeatureC", "FeatureD", "FeatureE"])
df3 = sc.parallelize([("U2", 1)]).toDF(["Key", "FeatureF"])
dfs = [df1, df2, df3]
Next we can extract the common schema:
output_schema = StructType(
[df1.schema.fields[0]] + list(chain(*[df.schema.fields[1:] for df in dfs]))
)
and transform all DataFrames:
transformed_dfs = [df.select(*[
lit(None).cast(c.dataType).alias(c.name) if c.name not in df.columns
else col(c.name)
for c in output_schema.fields
]) for df in dfs]
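The select above pads each DataFrame out to the full output schema, filling the columns it doesn't have with typed nulls. The same logic, illustrated in plain Python rather than the Spark API (column names taken from the example; `pad_row` is a hypothetical helper):

```python
# Plain-Python sketch of the padding step: project each row onto the
# full output column list, filling missing features with None.
all_columns = ["Key", "FeatureA", "FeatureB", "FeatureC",
               "FeatureD", "FeatureE", "FeatureF"]

def pad_row(row, columns=all_columns):
    """Return a dict with every output column; absent ones become None."""
    return {c: row.get(c) for c in columns}

row = {"Key": "U1", "FeatureC": 0, "FeatureD": 0, "FeatureE": 1}
padded = pad_row(row)
# padded now has all seven columns; FeatureA/FeatureB/FeatureF are None
```

Once every dataset has the same columns in the same order, they can be safely unioned.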
Finally, a union and a dummy aggregation:
combined = reduce(DataFrame.unionAll, transformed_dfs)
exprs = [max(c).alias(c) for c in combined.columns[1:]]
result = combined.repartition(col("Key")).groupBy(col("Key")).agg(*exprs)
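Since each feature value appears in exactly one of the padded rows for a given key (the rest are null), `max` simply picks the non-null value. A plain-Python analogue of that union-then-aggregate reduction (sample rows invented for illustration, not Spark code):

```python
# Plain-Python sketch of the union + max aggregation: concatenate rows
# from all padded datasets, then keep the max per feature for each key
# (None always loses to a real value).
from collections import defaultdict

rows = [
    {"Key": "U1", "FeatureA": 0,    "FeatureB": 1,    "FeatureF": None},
    {"Key": "U1", "FeatureA": None, "FeatureB": None, "FeatureF": 0},
    {"Key": "U2", "FeatureA": 1,    "FeatureB": 1,    "FeatureF": None},
    {"Key": "U2", "FeatureA": None, "FeatureB": None, "FeatureF": 1},
]

merged = defaultdict(dict)
for row in rows:
    for col_name, value in row.items():
        if col_name == "Key":
            continue
        current = merged[row["Key"]].get(col_name)
        if current is None or (value is not None and value > current):
            merged[row["Key"]][col_name] = value
# merged["U2"] collapses to {"FeatureA": 1, "FeatureB": 1, "FeatureF": 1}
```

Because the aggregation only has to pick one non-null value per column, it is cheap compared to a chain of a dozen joins.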
If there is more than one row per key but individual columns are still atomic, you can try to replace max with collect_list / collect_set followed by explode.
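With several rows per key, a max would silently drop values; collecting them keeps every observation. A plain-Python sketch of what collect_set does per key (sample data invented for illustration):

```python
# Plain-Python analogue of collect_set: gather every distinct non-null
# value observed for a feature, instead of keeping only the maximum.
from collections import defaultdict

rows = [
    {"Key": "U1", "FeatureA": 0},
    {"Key": "U1", "FeatureA": 1},
    {"Key": "U2", "FeatureA": 1},
    {"Key": "U2", "FeatureA": None},
]

collected = defaultdict(set)
for row in rows:
    if row["FeatureA"] is not None:
        collected[row["Key"]].add(row["FeatureA"])
# collected["U1"] holds both observed values: {0, 1}
```

In Spark, a subsequent explode would then turn each collected set back into one row per value.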