I am using Spark 1.3
# Read from text file, parse it and then do some basic filtering to get data1
data1.registerTempTable('data1')

# Read from text file, parse it and then do some basic filtering to get data2
data2.registerTempTable('data2')

# Perform join
data_joined = data1.join(data2, data1.id == data2.id)
My data is quite skewed and data2 (a few KB) << data1 (tens of GB), so the join performance is quite bad. I have been reading about broadcast joins, but I am not sure how to do the same using the Python API.
Broadcast variables keep a read-only copy of data cached on every machine in the cluster, so the value is shipped to each worker once rather than with every task. PySpark exposes this through the Broadcast class.
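A minimal sketch of the mechanics (the lookup table and variable names here are illustrative):

# Build a small lookup table on the driver and broadcast it once
lookup_bd = sc.broadcast({1: 'a', 2: 'b'})

# Tasks read the cached copy through .value instead of
# re-serializing the table with every closure
rdd = sc.parallelize([1, 2, 3])
print(rdd.map(lambda k: lookup_bd.value.get(k, 'unknown')).collect())
# ['a', 'b', 'unknown']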
A broadcast join in Spark is preferred when joining one small DataFrame with a large one. The requirement is that the small DataFrame fits comfortably in memory, so that a copy of it can be sent to every executor and joined against the large DataFrame locally, which avoids a shuffle and boosts the performance of the join.
Broadcast join is an important part of Spark SQL's execution engine. When used, it performs a join on two relations by first broadcasting the smaller one to all Spark executors, then evaluating the join criteria with each executor's partitions of the other relation.
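Spark SQL can also choose a broadcast join automatically when it can estimate that one side is small enough; the cutoff is governed by spark.sql.autoBroadcastJoinThreshold (in bytes). A sketch of raising it, assuming a SQLContext named sqlContext and an illustrative 50 MB limit:

# Broadcast any relation Spark SQL estimates at under ~50 MB
sqlContext.sql("SET spark.sql.autoBroadcastJoinThreshold=52428800")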
Spark 1.3 doesn't support broadcast joins with the DataFrame API. In Spark >= 1.5.0 you can use the broadcast function to apply a broadcast join:
from pyspark.sql.functions import broadcast

data1.join(broadcast(data2), data1.id == data2.id)
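You can check that the planner actually picked a broadcast by printing the physical plan (the exact output varies by version; look for a BroadcastHashJoin node):

# Print the physical plan to verify the join strategy
data1.join(broadcast(data2), data1.id == data2.id).explain()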
For older versions the only option is to convert to RDD and apply the same logic as in other languages. Roughly something like this:
from pyspark.sql import Row
from pyspark.sql.types import StructType

# Create a dictionary where keys are join keys
# and values are lists of rows
data2_bd = sc.broadcast(
    data2.map(lambda r: (r.id, r)).groupByKey().collectAsMap())

# Define a new row with fields from both DFs
output_row = Row(*data1.columns + data2.columns)

# And an output schema
output_schema = StructType(data1.schema.fields + data2.schema.fields)

# Given row x, extract a list of corresponding rows from broadcast
# and output a list of merged rows
def gen_rows(x):
    return [output_row(*x + y) for y in data2_bd.value.get(x.id, [])]

# flatMap and create a new data frame
joined = data1.rdd.flatMap(lambda row: gen_rows(row)).toDF(output_schema)
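For illustration, here is how that behaves on toy data (the sqlContext name and the sample rows are assumptions for the example):

# Hypothetical sample inputs -- define these before running the snippet above
data1 = sqlContext.createDataFrame([(1, 'a'), (2, 'b')], ['id', 'v1'])
data2 = sqlContext.createDataFrame([(1, 'x')], ['id', 'v2'])

# `joined` then holds a single merged row for id == 1; id == 2 is dropped
# because gen_rows returns an empty list for it, i.e. this manual
# broadcast join behaves like an inner join
joined.show()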