
Fill Pyspark dataframe column null values with average value from same column

With a DataFrame like this:

rdd_2 = sc.parallelize([
    (0, 10, 223, "201601"), (0, 10, 83, "2016032"),
    (1, 20, None, "201602"), (1, 20, 3003, "201601"), (1, 20, None, "201603"),
    (2, 40, 2321, "201601"), (2, 30, 10, "201602"), (2, 61, None, "201601")
])

df_data = sqlContext.createDataFrame(rdd_2, ["id", "type", "cost", "date"])
df_data.show()

+---+----+----+-------+
| id|type|cost|   date|
+---+----+----+-------+
|  0|  10| 223| 201601|
|  0|  10|  83|2016032|
|  1|  20|null| 201602|
|  1|  20|3003| 201601|
|  1|  20|null| 201603|
|  2|  40|2321| 201601|
|  2|  30|  10| 201602|
|  2|  61|null| 201601|
+---+----+----+-------+

I need to fill the null values with the average of the existing values, with the expected result being

+---+----+----+-------+
| id|type|cost|   date|
+---+----+----+-------+
|  0|  10| 223| 201601|
|  0|  10|  83|2016032|
|  1|  20|1128| 201602|
|  1|  20|3003| 201601|
|  1|  20|1128| 201603|
|  2|  40|2321| 201601|
|  2|  30|  10| 201602|
|  2|  61|1128| 201601|
+---+----+----+-------+

where 1128 is the average of the existing values. I need to do that for several columns.
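For reference, 1128 is the mean of the five non-null cost values: (223 + 83 + 3003 + 2321 + 10) / 5 = 1128. A quick sanity check in PySpark (avg skips nulls):

from pyspark.sql.functions import avg

# avg() ignores nulls, so this is the mean of the five existing cost values
df_data.agg(avg("cost")).first()[0]  # 1128.0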

My current approach is to use na.fill:

# One Spark action per column: compute each mean separately, then collect it
fill_values = {column: df_data.agg({column: "mean"}).flatMap(list).collect()[0]
               for column in df_data.columns if column not in ['date', 'id']}
df_data = df_data.na.fill(fill_values)

+---+----+----+-------+
| id|type|cost|   date|
+---+----+----+-------+
|  0|  10| 223| 201601|
|  0|  10|  83|2016032|
|  1|  20|1128| 201602|
|  1|  20|3003| 201601|
|  1|  20|1128| 201603|
|  2|  40|2321| 201601|
|  2|  30|  10| 201602|
|  2|  61|1128| 201601|
+---+----+----+-------+

But this is very cumbersome. Any ideas?

asked Jun 10 '16 by Ivan


1 Answer

Well, one way or another you have to:

  • compute statistics
  • fill the blanks

That pretty much limits what you can improve here; still, you can:

  • replace flatMap(list).collect()[0] with first()[0] or tuple unpacking
  • compute all stats with a single action
  • use the built-in Row.asDict() method to extract the values as a dictionary

The final result could look like this:

from pyspark.sql.functions import avg

def fill_with_mean(df, exclude=set()):
    # One aggregation computes the mean of every non-excluded column
    stats = df.agg(*(
        avg(c).alias(c) for c in df.columns if c not in exclude
    ))
    # first() yields a single Row; asDict() maps column name -> mean
    return df.na.fill(stats.first().asDict())

fill_with_mean(df_data, ["id", "date"])
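Applied to df_data above, this computes one row of means in a single action and fills the three null cost values with 1128, reproducing the expected output; type has no nulls, so its mean is computed but never used.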

In Spark 2.2 or later you can also use Imputer. See Replace missing values with mean - Spark Dataframe.
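A minimal sketch of the Imputer approach, assuming Spark 2.2+: Imputer only accepts float/double input columns, so the integer cost column is cast first; the cost_d and cost_imputed column names are made up for this example.

from pyspark.ml.feature import Imputer
from pyspark.sql.functions import col

# Imputer requires FloatType/DoubleType inputs, so cast the integer column
df_double = df_data.withColumn("cost_d", col("cost").cast("double"))

imputer = Imputer(
    strategy="mean",               # fill nulls with the column mean
    inputCols=["cost_d"],          # hypothetical casted column
    outputCols=["cost_imputed"],   # hypothetical output column
)
df_filled = imputer.fit(df_double).transform(df_double)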

answered Oct 12 '22 by zero323