
PySpark: Take average of a column after using filter function

I am using the following code to get the average age of people whose salary is greater than some threshold.

dataframe.filter(df['salary'] > 100000).agg({"avg": "age"})

The age column is numeric (float), but I am still getting this error:

py4j.protocol.Py4JJavaError: An error occurred while calling o86.agg. 
: scala.MatchError: age (of class java.lang.String)

Is there another way to obtain the average (and similar aggregates) without using the groupBy function or SQL queries?

Asked Sep 13 '15 by Harit Vishwakarma



2 Answers

The column name should be the key and the aggregation function the value. In your code they are swapped, so Spark tries to interpret the string "age" as a function name, hence the scala.MatchError:

dataframe.filter(df['salary'] > 100000).agg({"age": "avg"})

Alternatively you can use pyspark.sql.functions:

from pyspark.sql.functions import col, avg

dataframe.filter(df['salary'] > 100000).agg(avg(col("age")))

It is also possible to use CASE ... WHEN (note that avg is needed as well as when):

from pyspark.sql.functions import avg, when

dataframe.select(avg(when(df['salary'] > 100000, df['age'])))
Answered Oct 23 '22 by zero323

You can also use an empty groupBy, which produces a single global aggregate:

dataframe.filter(df['salary'] > 100000).groupBy().avg('age')
Answered Oct 23 '22 by Ahmed Gehad