 

PySpark DataFrame Imputations -- Replace Unknown & Missing Values with Column Mean based on specified condition

Given a Spark dataframe, I would like to compute a column mean based on the non-missing and non-unknown values for that column. I would then like to take this mean and use it to replace the column's missing & unknown values.

For example, assume I'm working with:

  • A DataFrame named df, where each record represents one individual and all columns are integer or numeric
  • A column named age (the age of each individual)
  • A column named missing_age (equal to 1 if that individual has no age, 0 otherwise)
  • A column named unknown_age (equal to 1 if that individual's age is unknown, 0 otherwise)

Then I can compute this mean as shown below.

calc_mean = df.where(
    (col("unknown_age") == 0) & (col("missing_age") == 0)
).agg(avg(col("age")))

Or, via SQL and window functions:

mean_compute = hiveContext.sql(
    "select avg(age) over() as mean from df "
    "where missing_age = 0 and unknown_age = 0"
)
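
Either way, the result is a one-row DataFrame rather than a scalar. One way to pull the value out as a plain Python number (using the calc_mean DataFrame from above):

# calc_mean is still a one-row DataFrame, not a number;
# first()[0] extracts the mean as a plain Python float
mean_age = calc_mean.first()[0]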

I don't want to use SQL/window functions if I can help it. My challenge has been taking this mean and replacing the unknown/missing values with it using non-SQL methods.

I've tried using when(), where(), replace(), withColumn, UDFs, and combinations... Regardless of what I do, I either get errors or the results aren't what I expect. Here's an example of one of many things I've tried that didn't work.

imputed = df.when((col("unknown_age") == 1) | (col("missing_age") == 1),
                  calc_mean).otherwise("age")

I've scoured the web but haven't found similar imputation-type questions, so any help is much appreciated. It could be something very simple that I've missed.

A side note -- I'm trying to apply this code to all columns in the Spark DataFrame that don't have unknown_ or missing_ in the column names. Can I just wrap the Spark-related code in a Python for loop and loop through all of the applicable columns to do this?

UPDATE:

I also figured out how to loop through columns. Here's an example.

for x in df.columns:
    if 'unknown_' not in x and 'missing_' not in x:
        avg_compute = df.where(df['missing_' + x] != 1).agg(avg(x)).first()[0]
        df = df.withColumn(x + 'mean_miss_imp',
                           when(df['missing_' + x] == 1, avg_compute)
                           .otherwise(df[x]))
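
For reference, a sketch of the same loop that also excludes rows flagged as unknown, assuming every value column x has matching missing_x and unknown_x flag columns (the _mean_imp suffix is just an illustrative name):

from pyspark.sql.functions import avg, col, when

for x in df.columns:
    if not x.startswith(('missing_', 'unknown_')):
        # mean over rows where neither flag is set
        m = (df.where((col('missing_' + x) != 1) & (col('unknown_' + x) != 1))
               .agg(avg(x)).first()[0])
        # replace flagged values with that mean
        df = df.withColumn(x + '_mean_imp',
                           when((col('missing_' + x) == 1) | (col('unknown_' + x) == 1), m)
                           .otherwise(col(x)))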
asked May 24 '16 by midnightfalcon




1 Answer

If age for unknown or missing individuals is some sentinel value (-1 in the example below):

from pyspark.sql.functions import col, avg, when

df = sc.parallelize([
    (10, 0, 0), (20, 0, 0), (-1, 1, 0), (-1, 0, 1)
]).toDF(["age", "missing_age", "unknown_age"])

avg_age = df.where(
    (col("unknown_age") != 1) & (col("missing_age") != 1)
).agg(avg("age")).first()[0]

df.withColumn("age_imp", when(
    (col("unknown_age") == 1) | (col("missing_age") == 1), avg_age
).otherwise(col("age")))
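
Note that withColumn returns a new DataFrame rather than modifying df in place, so in practice you'd assign the result:

# DataFrames are immutable; keep the imputed frame under a name
df_imp = df.withColumn("age_imp", when(
    (col("unknown_age") == 1) | (col("missing_age") == 1), avg_age
).otherwise(col("age")))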

If age is NULL for unknown or missing individuals, you can simplify this to:

df = sc.parallelize([
    (10, 0, 0), (20, 0, 0), (None, 1, 0), (None, 0, 1)
]).toDF(["age", "missing_age", "unknown_age"])

df.na.fill(df.na.drop().agg(avg("age")).first()[0], ["age"])
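
To apply the NULL-based version to every applicable column in one pass, here is a sketch that computes per-column means and uses the dict form of na.fill; it assumes the missing/unknown sentinels have already been converted to NULLs:

from pyspark.sql.functions import avg

value_cols = [c for c in df.columns
              if not c.startswith(('missing_', 'unknown_'))]
# means computed over complete rows only (na.drop removes any row with a NULL)
means = df.na.drop().agg(*[avg(c).alias(c) for c in value_cols]).first().asDict()
df_filled = df.na.fill(means)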
answered Sep 28 '22 by zero323