PySpark: calculate mean, standard deviation and those values around the mean in one step

My raw data comes in a tabular format. It contains observations from different variables. Each observation consists of the variable name, the timestamp and the value at that time.

Variable [string], Time [datetime], Value [float]

The data is stored as Parquet in HDFS and loaded into a Spark DataFrame (df).

Now I want to calculate default statistics like the mean, standard deviation and others for each variable. Afterwards, once the mean has been retrieved, I want to filter/count those values of that variable that lie close to the mean.
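
For the summary part alone, a plain groupBy aggregation is a reasonable starting point; a minimal sketch (assuming a Spark version where stddev_pop is available; stats is just a name chosen here):

from pyspark.sql import functions as F

stats = df.groupBy("Variable").agg(
    F.mean("Value").alias("mean"),
    F.stddev_pop("Value").alias("std_deviation"),
    F.count("Value").alias("count"))

The harder part is getting those per-variable values back onto every row so that the values around the mean can be flagged and counted in the same pass.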

Based on the answer to my other question, I came up with this code:

from pyspark.sql.window import Window
from pyspark.sql.functions import *
from pyspark.sql.types import *

w1 = Window.partitionBy("Variable")
w2 = Window.partitionBy("Variable").orderBy("Time")

def stddev_pop_w(col, w):
    #Built-in stddev doesn't support windowing
    return sqrt(avg(col * col).over(w) - pow(avg(col).over(w), 2))

def isInRange(value, mean, stddev, radius):
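    # 1 if the value lies within radius * stddev of the mean,
    # 0 otherwise, -1 if the inputs cannot be evaluated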
    try:
        if (abs(value - mean) < radius * stddev):
            return 1
        else:
            return 0
    except AttributeError:
        return -1

delta = col("Time").cast("long") - lag("Time", 1).over(w2).cast("long")
#f = udf(lambda (value, mean, stddev, radius): abs(value - mean) < radius * stddev, IntegerType())
#f2 = udf(lambda value, mean, stddev: isInRange(value, mean, stddev, 2), IntegerType())
#f3 = udf(lambda value, mean, stddev: isInRange(value, mean, stddev, 3), IntegerType())

df_ = df_all \
    .withColumn("mean", mean("Value").over(w1)) \
    .withColumn("std_deviation", stddev_pop_w(col("Value"), w1)) \
    .withColumn("delta", delta) \
#    .withColumn("stddev_2", f2("Value", "mean", "std_deviation")) \
#    .withColumn("stddev_3", f3("Value", "mean", "std_deviation")) \

#df2.show(5, False)

Question: The last two commented-out lines won't work. They give an AttributeError because the incoming values for stddev and mean are null. I guess this happens because I'm referring to columns that are themselves just calculated on the fly and have no value at that moment. But is there a way to achieve this?

Currently I'm doing a second run like this:

df = df_.select("*", \
    abs(df_.Value - df_.mean).alias("max_deviation_mean"), \
    when(abs(df_.Value - df_.mean) < 2 * df_.std_deviation, 1).otherwise(0).alias("std_dev_mean_2"), \
    when(abs(df_.Value - df_.mean) < 3 * df_.std_deviation, 1).otherwise(0).alias("std_dev_mean_3"))
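
For completeness, the second run can presumably be folded into the first chain by referencing the freshly created columns by name instead of passing them through a UDF; this is a sketch only, reusing the same df_all, w1, delta and stddev_pop_w defined above:

df_ = df_all \
    .withColumn("mean", mean("Value").over(w1)) \
    .withColumn("std_deviation", stddev_pop_w(col("Value"), w1)) \
    .withColumn("delta", delta) \
    .withColumn("max_deviation_mean", abs(col("Value") - col("mean"))) \
    .withColumn("std_dev_mean_2", when(abs(col("Value") - col("mean")) < 2 * col("std_deviation"), 1).otherwise(0)) \
    .withColumn("std_dev_mean_3", when(abs(col("Value") - col("mean")) < 3 * col("std_deviation"), 1).otherwise(0))

Here abs, when and col are the pyspark.sql.functions versions brought in by the star import, so they operate on columns as intended.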


2 Answers

The solution is to use the RDD aggregateByKey function (the DataFrame is first turned into a pair RDD keyed by the variable name). It aggregates the values within each partition before shuffling those partial aggregates across the nodes, where they are combined into one resulting value per key.

The pseudo-code looks like this. It is inspired by this tutorial, but it uses two instances of StatCounter because we are summing up two different series (timestamps and values) at once:

from pyspark.statcounter import StatCounter

# Turn the DataFrame into a pair RDD keyed by the variable name;
# value[0] is the timestamp (cast to seconds) and value[1] is the float value.
# Two instances of StatCounter are used to sum up two different series at once.
pairs = (df
    .selectExpr("Variable", "cast(Time as long) as ts", "Value")
    .rdd
    .map(lambda row: (row.Variable, (row.ts, row.Value))))

def mergeValues(s1, v1, s2, v2):
    # fold one observation into each accumulator and return the accumulator pair
    s1.merge(v1)
    s2.merge(v2)
    return (s1, s2)

def combineStats(s1, s2):
    # combine the per-partition accumulator pairs
    s1[0].mergeStats(s2[0])
    s1[1].mergeStats(s2[1])
    return s1

(pairs.aggregateByKey((StatCounter(), StatCounter()),
        lambda s, values: mergeValues(s[0], values[0], s[1], values[1]),
        combineStats)
    .mapValues(lambda s: (s[0].min(), s[0].max(), s[1].min(), s[1].max(),
                          s[1].mean(), s[1].variance(), s[1].stdev(), s[1].count()))
    .collect())
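
The collected result is a list of (variable, statistics) pairs, with the tuple fields in the order given to mapValues (time range first, then the value statistics). A small driver-side usage sketch, where result stands for the list returned by the collect() call above (the name itself is not part of the code):

for variable, (t_min, t_max, v_min, v_max, v_mean, v_variance, v_stdev, v_count) in result:
    print(variable, v_mean, v_stdev, v_count)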


This cannot work because when you execute

from pyspark.sql.functions import *

you shadow the built-in abs with pyspark.sql.functions.abs, which expects a column, not a local Python value, as input.

Also, the UDF you created doesn't handle NULL entries.

  • Don't use import * unless you're aware of what exactly is imported. Instead, alias

    from pyspark.sql.functions import abs as abs_
    

    or import the module

    from pyspark.sql import functions as sqlf
    
    sqlf.col("x")
    
  • Always check the input inside a UDF or, even better, avoid UDFs unless necessary (see the sketch below).
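
For illustration, a null-safe variant of the range check that avoids a UDF altogether, built only from column expressions (a sketch; it assumes mean and std_deviation have already been added as columns, as in the question, and df_flagged is just a name chosen here):

from pyspark.sql import functions as sqlf

df_flagged = df_.withColumn(
    "std_dev_mean_2",
    sqlf.when(sqlf.col("mean").isNull() | sqlf.col("std_deviation").isNull(), -1)
        .when(sqlf.abs(sqlf.col("Value") - sqlf.col("mean")) < 2 * sqlf.col("std_deviation"), 1)
        .otherwise(0))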
