
Date Arithmetic with Multiple Columns in PySpark

I'm trying to do some moderately complicated date arithmetic using multiple columns in a PySpark DataFrame. Basically, I have a column called number that represents the number of weeks after the created_at timestamp that I need to filter on. In PostgreSQL you can multiply an interval by the value in a column, but I can't figure out how to do this in PySpark using either the SQL API or the Python API. Any help here would be much appreciated!

import datetime
from pyspark.sql import SQLContext
from pyspark.sql import Row
from pyspark import SparkContext

sc = SparkContext()
sqlContext = SQLContext(sc)
start_date = datetime.date(2020,1,1)

my_df = sc.parallelize([
        Row(id=1, created_at=datetime.datetime(2020, 1, 1), number=1,  metric=10),
        Row(id=1, created_at=datetime.datetime(2020, 1, 1), number=2,  metric=10),
        Row(id=1, created_at=datetime.datetime(2020, 1, 1), number=3,  metric=10),
        Row(id=2, created_at=datetime.datetime(2020, 1, 15), number=1,  metric=20),
        Row(id=2, created_at=datetime.datetime(2020, 1, 15), number=2,  metric=20),
        Row(id=3, created_at=datetime.datetime(2020, 7, 1), number=7,  metric=30),
        Row(id=3, created_at=datetime.datetime(2020, 7, 1), number=8,  metric=30),
        Row(id=3, created_at=datetime.datetime(2020, 7, 1), number=9,  metric=30),
        Row(id=3, created_at=datetime.datetime(2020, 7, 1), number=10, metric=30),
    ]).toDF()


# This doesn't work!
new_df = my_df.where("created_at + interval 7 days * number > '" + start_date.strftime("%Y-%m-%d") +"'")
# Neither does this!
new_df = my_df.filter(my_df.created_at + datetime.timedelta(days=my_df.number * 7)).date() > start_date.date()

One possible workaround would be to convert the date to a string, use Python's datetime library to turn that string back into a datetime object, and then do the arithmetic and filtering row by row, but that seems crazy.
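(For what it's worth, that workaround would look roughly like the sketch below. It skips the string round-trip, since a Python UDF already receives a timestamp column as a datetime object; the add_weeks helper is my own name, not a built-in, and the whole thing ships every row through Python, which is why it feels like overkill.)

import datetime

from pyspark.sql.functions import udf
from pyspark.sql.types import DateType

def _add_weeks(created_at, number):
    # created_at arrives as a Python datetime, number as a plain int
    return (created_at + datetime.timedelta(weeks=number)).date()

# Hypothetical helper, not a built-in: wrap the Python function as a UDF
add_weeks = udf(_add_weeks, DateType())

workaround_df = my_df.filter(add_weeks(my_df.created_at, my_df.number) > start_date)
workaround_df.show()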

asked Apr 12 '16 by TuringMachin

1 Answer

Alright, I figured out a way forward using expr together with the built-in date_add function.

from pyspark.sql.functions import expr, date_add

# date_add is invoked inside the SQL expression string, so the day count can be
# a column expression (number * 7) rather than a literal integer
new_df = my_df.withColumn('test', expr('date_add(created_at, number * 7)'))
filtered = new_df.filter(new_df.test > start_date)
filtered.show()

Would love some insight into how/why this works in a general way, though, if someone else wants to add on!
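For reference, here's a similar sketch (not part of the original answer, so treat the details as an assumption) that pushes both the arithmetic and the comparison into a single SQL expression handed to where(), reusing my_df and start_date from the question:

from pyspark.sql.functions import expr

# Build the comparison date into the SQL string; cast() keeps it a date-to-date comparison
filtered = my_df.where(
    expr("date_add(created_at, number * 7) > cast('{0}' as date)".format(
        start_date.strftime("%Y-%m-%d")))
)
filtered.show()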

answered Oct 22 '22 by TuringMachin