I have two files in which I store:

GeoLocTable - ipstart, ipend, ...additional Geo location data
Recordstable - INET_ATON, ...3 more fields

The IPs are stored as integers (using inet_aton()). I loaded both files into dataframes, registered them as temp tables, and tried to join them with a Spark SQL statement like so -

"select a.*, b.* from Recordstable a left join GeoLocTable b on a.INET_ATON between b.ipstart and b.ipend"
There are about 850K records in Recordstable and about 2.5M records in GeoLocTable. The join as written runs for about 2 hours with about 20 executors.
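For reference, the PySpark setup looks roughly like this (paths, formats and variable names below are placeholders, not my actual code):

# Assuming a SparkSession named spark; file paths and formats are placeholders.
records_df = spark.read.csv("/path/to/records", header=True, inferSchema=True)  # Recordstable source
geoloc_df = spark.read.csv("/path/to/geoloc", header=True, inferSchema=True)    # GeoLocTable source

records_df.createOrReplaceTempView("Recordstable")
geoloc_df.createOrReplaceTempView("GeoLocTable")

joined = spark.sql("""
    select a.*, b.*
    from Recordstable a
    left join GeoLocTable b
      on a.INET_ATON between b.ipstart and b.ipend
""")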
I have tried caching and broadcasting GeoLocTable, but it does not really seem to help. I have also bumped up spark.sql.autoBroadcastJoinThreshold=300000000 and spark.sql.shuffle.partitions=600.
The Spark UI shows a BroadcastNestedLoopJoin being performed. Is this the best I should be expecting? I tried searching for the conditions under which this type of join is chosen, but the documentation seems sparse.
PS - I am using PySpark to work with Spark.
The source of the problem is pretty simple. When you execute a join and the join condition is not equality based, the only thing Spark can do right now is expand it to a Cartesian product followed by a filter, which is pretty much what happens inside BroadcastNestedLoopJoin. So logically you have a huge nested loop which tests all 850K * 2.5M pairs of records.
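You can confirm this by inspecting the physical plan of the query from the question (a quick sketch, assuming the temp tables are registered as described; the exact plan text varies by Spark version):

# Print the physical plan of the range join.
spark.sql("""
    select a.*, b.*
    from Recordstable a
    left join GeoLocTable b
      on a.INET_ATON between b.ipstart and b.ipend
""").explain()
# Expect a BroadcastNestedLoopJoin (or a CartesianProduct when nothing is
# broadcast) because the join condition contains no equality predicate.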
This approach is obviously extremely inefficient. Since it looks like the lookup table fits into memory, the simplest improvement is to use a local, sorted data structure instead of a Spark DataFrame. Assuming your data looks like this:
geo_loc_table = sc.parallelize([
    (1, 10, "foo"), (11, 36, "bar"), (37, 59, "baz")
]).toDF(["ipstart", "ipend", "loc"])

records_table = sc.parallelize([
    (1, 11), (2, 38), (3, 50)
]).toDF(["id", "inet"])
We can project and sort the reference data by ipstart and create a broadcast variable:
geo_start_bd = sc.broadcast(geo_loc_table
    .select("ipstart")
    .orderBy("ipstart")
    .rdd.flatMap(lambda x: x)  # go through the RDD to flatten Rows into plain values
    .collect())
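For the toy data the broadcast value is simply the sorted list of range starts:

geo_start_bd.value
# [1, 11, 37]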
Next we'll use a UDF and the bisect module to augment records_table:
from bisect import bisect_right
from pyspark.sql.functions import udf
from pyspark.sql.types import LongType

# https://docs.python.org/3/library/bisect.html#searching-sorted-lists
def find_le(x):
    'Find rightmost value less than or equal to x'
    i = bisect_right(geo_start_bd.value, x)
    if i:
        return geo_start_bd.value[i - 1]
    return None

records_table_with_ipstart = records_table.withColumn(
    "ipstart", udf(find_le, LongType())("inet")
)
and finally join both datasets:
records_table_with_ipstart.join(geo_loc_table, ["ipstart"], "left")
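For the toy data above the result looks roughly like this (exact column order may differ):

(records_table_with_ipstart
    .join(geo_loc_table, ["ipstart"], "left")
    .orderBy("id")
    .show())
# +-------+---+----+-----+---+
# |ipstart| id|inet|ipend|loc|
# +-------+---+----+-----+---+
# |     11|  1|  11|   36|bar|
# |     37|  2|  38|   59|baz|
# |     37|  3|  50|   59|baz|
# +-------+---+----+-----+---+

Note that find_le matches only on the range start, so if your ranges can have gaps you may want an extra filter on inet <= ipend after the join to drop IPs that fall between ranges.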
Another possibility is to use the Python version of the join_with_range API in Apache DataFu to do the join. This will explode your ranges into multiple rows so Spark can still do an equi-join.
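The idea is roughly the following (a hand-rolled, illustrative sketch of the bucketing trick using the column names from the question, not DataFu's actual implementation; F.sequence needs Spark 2.4+, and an inner join is used just to show the shape):

from pyspark.sql import functions as F

bucket_size = 10  # hypothetical bucket width; roughly what decrease_factor controls

# Explode each [ipstart, ipend] range into every bucket it overlaps.
geo_bucketed = geo_loc_table.withColumn(
    "bucket",
    F.explode(F.sequence(
        F.floor(F.col("ipstart") / bucket_size),
        F.floor(F.col("ipend") / bucket_size),
    )),
)

# Tag each record with the bucket its IP falls into.
records_bucketed = records_table.withColumn(
    "bucket", F.floor(F.col("INET_ATON") / bucket_size)
)

# Ordinary equi-join on the bucket, followed by the exact range check.
bucketed_join = (
    records_bucketed.join(geo_bucketed, "bucket")
    .where(F.col("INET_ATON").between(F.col("ipstart"), F.col("ipend")))
    .drop("bucket")
)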
To use DataFu itself, you need to start PySpark with the following parameters:
export PYTHONPATH=datafu-spark_2.11-1.6.0.jar
pyspark --jars datafu-spark_2.11-1.6.0.jar --conf spark.executorEnv.PYTHONPATH=datafu-spark_2.11-1.6.0.jar
And then you would do the join like this:
from pyspark_utils.df_utils import PySparkDFUtils
df_utils = PySparkDFUtils()
func_joinWithRange_res = df_utils.join_with_range(
    df_single=records_table,
    col_single="INET_ATON",
    df_range=geo_loc_table,
    col_range_start="ipstart",
    col_range_end="ipend",
    decrease_factor=10,
)
func_joinWithRange_res.registerTempTable("joinWithRange")
The decrease_factor argument (10 here) is there to minimize the number of exploded rows: it affects the number of "buckets" created. You can play with it in order to improve performance.
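Once registered, the joined result can be used directly as a dataframe or queried with SQL via the temp table (a small usage sketch):

func_joinWithRange_res.show()
spark.sql("select * from joinWithRange").show()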
Full disclosure - I am a member of DataFu.