We have a dataframe like the one below:
+----+-----+
|Flag|value|
+----+-----+
|1   |5    |
|1   |4    |
|1   |3    |
|1   |5    |
|1   |6    |
|1   |4    |
|1   |7    |
|1   |5    |
|1   |2    |
|1   |3    |
|1   |2    |
|1   |6    |
|1   |9    |
+----+-----+
After a normal cumulative sum we get this:
+----+-----+------+
|Flag|value|cumsum|
+----+-----+------+
|1   |5    |5     |
|1   |4    |9     |
|1   |3    |12    |
|1   |5    |17    |
|1   |6    |23    |
|1   |4    |27    |
|1   |7    |34    |
|1   |5    |39    |
|1   |2    |41    |
|1   |3    |44    |
|1   |2    |46    |
|1   |6    |52    |
|1   |9    |61    |
+----+-----+------+
Now what we want is for the cumulative sum to reset when a specific condition is met, for example when it crosses 20.
Below is the expected output:
+----+-----+------+--------+
|Flag|value|cumsum|expected|
+----+-----+------+--------+
|1   |5    |5     |5       |
|1   |4    |9     |9       |
|1   |3    |12    |12      |
|1   |5    |17    |17      |
|1   |6    |23    |23      |
|1   |4    |27    |4       | <----- reset
|1   |7    |34    |11      |
|1   |5    |39    |16      |
|1   |2    |41    |18      |
|1   |3    |44    |21      |
|1   |2    |46    |2       | <----- reset
|1   |6    |52    |8       |
|1   |9    |61    |17      |
+----+-----+------+--------+
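In plain Python terms the rule is: once the running total crosses 20, the next row starts a new total. A minimal sketch of that rule (the reset_cumsum helper and the hard-coded limit of 20 are just for illustration):

def reset_cumsum(values, limit=20):
    # Running total that starts over on the row after it crosses the limit
    total, out = 0, []
    for v in values:
        total = v if total > limit else total + v
        out.append(total)
    return out

# reset_cumsum([5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9])
# -> [5, 9, 12, 17, 23, 4, 11, 16, 18, 21, 2, 8, 17]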
This is how we are calculating the cumulative sum. Note that the window needs an ordering and a running frame; without them, F.sum over the partition returns the grand total on every row instead of a running total:
import pyspark.sql.functions as F
from pyspark.sql.window import Window
win_counter = Window.partitionBy("flag").orderBy(F.monotonically_increasing_id()).rowsBetween(Window.unboundedPreceding, Window.currentRow)
df_partitioned = df_partitioned.withColumn('cumsum', F.sum(F.col('value')).over(win_counter))
There are two ways I've found to solve it without a UDF:
from pyspark.sql.window import Window
import pyspark.sql.functions as f

df = spark.createDataFrame([
    (1, 5), (1, 4), (1, 3), (1, 5), (1, 6), (1, 4),
    (1, 7), (1, 5), (1, 2), (1, 3), (1, 2), (1, 6), (1, 9)
], schema='Flag int, value int')

# Running window over the original row order within each flag
w = (Window
     .partitionBy('flag')
     .orderBy(f.monotonically_increasing_id())
     .rowsBetween(Window.unboundedPreceding, Window.currentRow))

# Collect the running list of values, then fold it left to right,
# restarting the accumulator once it has reached 20
df = df.withColumn('values', f.collect_list('value').over(w))
expr = "AGGREGATE(values, 0, (acc, el) -> IF(acc < 20, acc + el, el))"
df = df.select('Flag', 'value', f.expr(expr).alias('cumsum'))
df.show(truncate=False)
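On Spark 3.1+ the same fold can also be written with f.aggregate on the DataFrame side instead of the SQL string; a minimal sketch under that version assumption, reusing the window w defined above (df_alt is just an illustrative name):

# Spark 3.1+ sketch: f.aggregate performs the same resetting fold as the AGGREGATE expr
df_alt = (spark.createDataFrame([
              (1, 5), (1, 4), (1, 3), (1, 5), (1, 6), (1, 4),
              (1, 7), (1, 5), (1, 2), (1, 3), (1, 2), (1, 6), (1, 9)
          ], schema='Flag int, value int')
          .withColumn('values', f.collect_list('value').over(w))
          .select('Flag', 'value',
                  f.aggregate('values', f.lit(0),
                              lambda acc, el: f.when(acc < 20, acc + el).otherwise(el)
                              ).alias('cumsum')))
df_alt.show(truncate=False)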
The second way goes through the RDD API:

df = spark.createDataFrame([
    (1, 5), (1, 4), (1, 3), (1, 5), (1, 6), (1, 4),
    (1, 7), (1, 5), (1, 2), (1, 3), (1, 2), (1, 6), (1, 9)
], schema='Flag int, value int')

def cumsum_by_flag(rows):
    # Walk one flag group in order, restarting the running sum
    # on the row after it exceeds 20
    cumsum, reset = 0, False
    for row in rows:
        if reset:
            cumsum = row.value
            reset = False
        else:
            cumsum += row.value
        reset = cumsum > 20
        yield row.value, cumsum

def unpack(value):
    # (flag, (value, cumsum)) -> (flag, value, cumsum)
    flag = value[0]
    value, cumsum = value[1]
    return flag, value, cumsum

rdd = df.rdd.keyBy(lambda row: row.Flag)
rdd = (rdd
       .groupByKey()
       .flatMapValues(cumsum_by_flag)
       .map(unpack))

df = rdd.toDF('Flag int, value int, cumsum int')
df.show(truncate=False)
Output:
+----+-----+------+
|Flag|value|cumsum|
+----+-----+------+
|1   |5    |5     |
|1   |4    |9     |
|1   |3    |12    |
|1   |5    |17    |
|1   |6    |23    |
|1   |4    |4     |
|1   |7    |11    |
|1   |5    |16    |
|1   |2    |18    |
|1   |3    |21    |
|1   |2    |2     |
|1   |6    |8     |
|1   |9    |17    |
+----+-----+------+
It's probably best to do this with a pandas_udf here.
import math
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType

pdf = pd.DataFrame({'flag': [1] * 13,
                    'id': range(13),
                    'value': [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]})
df = spark.createDataFrame(pdf)

# Add the cumsum column up front so that df.schema already matches
# the schema the grouped-map UDF must return
df = df.withColumn('cumsum', F.lit(math.inf))

@pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
def _calc_cumsum(pdf):
    # Process one flag group at a time, ordered by id
    pdf.sort_values(by=['id'], inplace=True, ascending=True)
    cumsums = []
    prev = None
    reset = False
    for v in pdf['value'].values:
        if prev is None:
            prev = v
        else:
            prev = prev + v if not reset else v
        cumsums.append(prev)
        reset = prev >= 20
    pdf['cumsum'] = cumsums
    return pdf

df = df.groupby('flag').apply(_calc_cumsum)
df.show()
The results:
+----+---+-----+------+
|flag| id|value|cumsum|
+----+---+-----+------+
| 1| 0| 5| 5.0|
| 1| 1| 4| 9.0|
| 1| 2| 3| 12.0|
| 1| 3| 5| 17.0|
| 1| 4| 6| 23.0|
| 1| 5| 4| 4.0|
| 1| 6| 7| 11.0|
| 1| 7| 5| 16.0|
| 1| 8| 2| 18.0|
| 1| 9| 3| 21.0|
| 1| 10| 2| 2.0|
| 1| 11| 6| 8.0|
| 1| 12| 9| 17.0|
+----+---+-----+------+
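On Spark 3.0+ the same grouped logic can be written with groupBy(...).applyInPandas, which supersedes the GROUPED_MAP pandas_udf; a minimal sketch meant as a drop-in for the decorator and the .apply(...) call above (calc_cumsum_pd is just an illustrative name):

def calc_cumsum_pd(pdf):
    # Same per-group logic as _calc_cumsum, written as a plain function
    pdf = pdf.sort_values(by=['id'])
    cumsums, prev, reset = [], None, False
    for v in pdf['value'].values:
        prev = v if prev is None or reset else prev + v
        cumsums.append(prev)
        reset = prev >= 20
    pdf['cumsum'] = cumsums
    return pdf

df = df.groupby('flag').applyInPandas(calc_cumsum_pd, schema=df.schema)
df.show()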