How do you interpolate a PySpark dataframe within grouped data?
For example:
I have a PySpark dataframe with the following columns:
+--------+-------------------+--------+
|webID |timestamp |counts |
+--------+-------------------+--------+
|John |2018-02-01 03:00:00|60 |
|John |2018-02-01 03:03:00|66 |
|John |2018-02-01 03:05:00|70 |
|John |2018-02-01 03:08:00|76 |
|Mo |2017-06-04 01:05:00|10 |
|Mo |2017-06-04 01:07:00|20 |
|Mo |2017-06-04 01:10:00|35 |
|Mo |2017-06-04 01:11:00|40 |
+--------+-------------------+--------+
I need to interpolate both John and Mo's count data to a datapoint every minute, within their own interval. I am open to any simple linear interpolation - but note that my real data is every few seconds and I want to interpolate to every second.
So the result should be:
+--------+-------------------+--------+
|webID |timestamp |counts |
+--------+-------------------+--------+
|John |2018-02-01 03:00:00|60 |
|John |2018-02-01 03:01:00|62 |
|John |2018-02-01 03:02:00|64 |
|John |2018-02-01 03:03:00|66 |
|John |2018-02-01 03:04:00|68 |
|John |2018-02-01 03:05:00|70 |
|John |2018-02-01 03:06:00|72 |
|John |2018-02-01 03:07:00|74 |
|John |2018-02-01 03:08:00|76 |
|Mo |2017-06-04 01:05:00|10 |
|Mo |2017-06-04 01:06:00|15 |
|Mo |2017-06-04 01:07:00|20 |
|Mo |2017-06-04 01:08:00|25 |
|Mo |2017-06-04 01:09:00|30 |
|Mo |2017-06-04 01:10:00|35 |
|Mo |2017-06-04 01:11:00|40 |
+--------+-------------------+--------+
New rows need to be added to my original dataframe. Looking for a PySpark solution.
If you use Python, the shortest way to get this done is to reuse existing Pandas functions with a GROUPED_MAP pandas udf:
from operator import attrgetter
from pyspark.sql.types import StructType
from pyspark.sql.functions import pandas_udf, PandasUDFType
def resample(schema, freq, timestamp_col="timestamp", **kwargs):
    @pandas_udf(
        StructType(sorted(schema, key=attrgetter("name"))),
        PandasUDFType.GROUPED_MAP)
    def _(pdf):
        pdf.set_index(timestamp_col, inplace=True)
        pdf = pdf.resample(freq).interpolate()
        pdf.ffill(inplace=True)
        pdf.reset_index(drop=False, inplace=True)
        pdf.sort_index(axis=1, inplace=True)
        return pdf
    return _
Applied to your data:
from pyspark.sql.functions import to_timestamp
df = spark.createDataFrame([
("John", "2018-02-01 03:00:00", 60),
("John", "2018-02-01 03:03:00", 66),
("John", "2018-02-01 03:05:00", 70),
("John", "2018-02-01 03:08:00", 76),
("Mo", "2017-06-04 01:05:00", 10),
("Mo", "2017-06-04 01:07:00", 20),
("Mo", "2017-06-04 01:10:00", 35),
("Mo", "2017-06-04 01:11:00", 40),
], ("webID", "timestamp", "counts")).withColumn(
"timestamp", to_timestamp("timestamp")
)
df.groupBy("webID").apply(resample(df.schema, "60S")).show()
it yields
+------+-------------------+-----+
|counts| timestamp|webID|
+------+-------------------+-----+
| 60|2018-02-01 03:00:00| John|
| 62|2018-02-01 03:01:00| John|
| 64|2018-02-01 03:02:00| John|
| 66|2018-02-01 03:03:00| John|
| 68|2018-02-01 03:04:00| John|
| 70|2018-02-01 03:05:00| John|
| 72|2018-02-01 03:06:00| John|
| 74|2018-02-01 03:07:00| John|
| 76|2018-02-01 03:08:00| John|
| 10|2017-06-04 01:05:00| Mo|
| 15|2017-06-04 01:06:00| Mo|
| 20|2017-06-04 01:07:00| Mo|
| 25|2017-06-04 01:08:00| Mo|
| 30|2017-06-04 01:09:00| Mo|
| 35|2017-06-04 01:10:00| Mo|
| 40|2017-06-04 01:11:00| Mo|
+------+-------------------+-----+
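Note that the GROUPED_MAP udf type and groupBy(...).apply(...) are deprecated in newer Spark releases. A minimal sketch of the same idea for Spark 3.x using applyInPandas (my adaptation, not part of the original answer; it assumes the original timestamps fall exactly on the resample grid, as they do in the example):
# Rough Spark 3.x equivalent of the resample helper above (an adaptation,
# not the original answer's code).
def resample_group(pdf):
    # Upsample to one row per minute, linearly interpolate the counts,
    # and forward-fill the group key, which is constant within a group.
    out = pdf.set_index("timestamp").resample("60S").asfreq()
    out["counts"] = out["counts"].interpolate()
    out["webID"] = out["webID"].ffill()
    return out.reset_index()

(df.groupBy("webID")
    .applyInPandas(resample_group, "webID string, timestamp timestamp, counts double")
    .show())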
Either variant works under the assumption that both the input and the interpolated data for a single webID
fit in the memory of a single node (in general, other exact, non-iterative solutions have to make similar assumptions). If that's not the case, you can easily approximate by taking overlapping windows
from pyspark.sql.functions import window
partial = (df
    .groupBy("webID", window("timestamp", "5 minutes", "3 minutes")["start"])
    .apply(resample(df.schema, "60S")))
and aggregating the final result
from pyspark.sql.functions import mean
(partial
.groupBy("webID", "timestamp")
.agg(mean("counts")
.alias("counts"))
# Order by key and timestamp, only for consistent presentation
.orderBy("webId", "timestamp")
.show())
This is of course much more expensive (there are two shuffles, and some values are computed multiple times), but it can also leave gaps if the overlap is not large enough to include the next observation, as you can see for John between 03:05:00 and 03:08:00 in the output below (one way to widen the overlap is sketched after the output).
+-----+-------------------+------+
|webID| timestamp|counts|
+-----+-------------------+------+
| John|2018-02-01 03:00:00| 60.0|
| John|2018-02-01 03:01:00| 62.0|
| John|2018-02-01 03:02:00| 64.0|
| John|2018-02-01 03:03:00| 66.0|
| John|2018-02-01 03:04:00| 68.0|
| John|2018-02-01 03:05:00| 70.0|
| John|2018-02-01 03:08:00| 76.0|
| Mo|2017-06-04 01:05:00| 10.0|
| Mo|2017-06-04 01:06:00| 15.0|
| Mo|2017-06-04 01:07:00| 20.0|
| Mo|2017-06-04 01:08:00| 25.0|
| Mo|2017-06-04 01:09:00| 30.0|
| Mo|2017-06-04 01:10:00| 35.0|
| Mo|2017-06-04 01:11:00| 40.0|
+-----+-------------------+------+
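One way to reduce such gaps (my adjustment, not part of the original answer) is to make the windows long enough that consecutive windows overlap by at least the largest gap between observations, for example:
# Hypothetical adjustment: 10-minute windows sliding every 3 minutes overlap
# by 7 minutes, which should be enough to cover John's 3-minute gaps, so the
# 03:05:00 and 03:08:00 observations end up in the same window.
partial = (df
    .groupBy("webID", window("timestamp", "10 minutes", "3 minutes")["start"])
    .apply(resample(df.schema, "60S")))
This trades even more duplicated computation for fewer missing rows; re-run the aggregation above on the new partial result.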
A native PySpark implementation (no UDFs) that tackles this problem:
import pyspark.sql.functions as F
resample_interval = 1 # Resample interval size in seconds
df_interpolated = (
df_data
# Get timestamp and Counts of previous measurement via window function
.selectExpr(
"webID",
"LAG(Timestamp) OVER (PARTITION BY webID ORDER BY Timestamp ASC) as PreviousTimestamp",
"Timestamp as NextTimestamp",
"LAG(Counts) OVER (PARTITION BY webID ORDER BY Timestamp ASC) as PreviousCounts",
"Counts as NextCounts",
)
# To determine resample interval round up start and round down end timeinterval to nearest interval boundary
.withColumn("PreviousTimestampRoundUp", F.expr(f"to_timestamp(ceil(unix_timestamp(PreviousTimestamp)/{resample_interval})*{resample_interval})"))
.withColumn("NextTimestampRoundDown", F.expr(f"to_timestamp(floor(unix_timestamp(NextTimestamp)/{resample_interval})*{resample_interval})"))
# Make sure we don't get any negative intervals (whole interval is within resample interval)
.filter("PreviousTimestampRoundUp<=NextTimestampRoundDown")
# Create resampled time axis by creating all "interval" timestamps between previous and next timestamp
.withColumn("Timestamp", F.expr(f"explode(sequence(PreviousTimestampRoundUp, NextTimestampRoundDown, interval {resample_interval} second)) as Timestamp"))
# Sequence has inclusive boundaries for both start and stop. Filter out duplicate Counts if original timestamp is exactly a boundary.
.filter("Timestamp<NextTimestamp")
# Interpolate Counts between previous and next
.selectExpr(
"webID",
"Timestamp",
"""(unix_timestamp(Timestamp)-unix_timestamp(PreviousTimestamp))
/(unix_timestamp(NextTimestamp)-unix_timestamp(PreviousTimestamp))
*(NextCounts-PreviousCounts)
+PreviousCounts
as Counts"""
)
)
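A usage sketch for the snippet above (my assumptions: df_data points at the example dataframe from the first answer, which resolves fine because Spark's column resolution is case-insensitive by default, and resample_interval is set to 60 before building df_interpolated so that one row per minute is produced, matching the expected output):
# Set these before evaluating the df_interpolated expression above.
df_data = df            # example dataframe from the first answer (assumption)
resample_interval = 60  # one row per minute instead of one per second

df_interpolated.orderBy("webID", "Timestamp").show(truncate=False)
Note that because of the Timestamp<NextTimestamp filter, the last observation of each webID (e.g. John at 03:08:00) is not emitted; if you need it, union the last original row of each group back in.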
I recently wrote a blog post explaining this method and showing that it scales much better on big datasets than the pandas udf method above: https://medium.com/delaware-pro/interpolate-big-data-time-series-in-native-pyspark-d270d4b592a1