 

How to filter rows for a specific aggregate with Spark SQL?

Normally, all rows in a group are passed to an aggregate function. I would like to filter rows with a condition so that only some rows within a group are passed to an aggregate function. Such an operation is possible in PostgreSQL. I would like to do the same thing with a Spark SQL DataFrame (Spark 2.0.0).

The code could probably look like this:

val df = ... // some data frame
df.groupBy("A").agg(
  max("B").where("B").less(10), // there is no such method as `where` :(
  max("C").where("C").less(5)
)

So for a data frame like this:

|  A|  B|  C|
|  1| 14|  4|
|  1|  9|  3|
|  2|  5|  6|

The result would be:

|  A|max(B)|max(C)|
|  1|     9|     4|
|  2|     5|  null|

Is it possible with Spark SQL?

Note that in general any other aggregate function than max could be used and there could be multiple aggregates over the same column with arbitrary filtering conditions.

Marcin Król asked Sep 26 '16


People also ask

What is AGG () in Spark?

agg(Column expr, Column... exprs) computes aggregates by specifying a series of aggregate columns.
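For instance, a minimal Scala sketch, assuming df is an existing DataFrame with numeric columns b and c (the column names are placeholders):

import org.apache.spark.sql.functions.{avg, col, max}

// agg takes one or more aggregate columns, each of which can be aliased
val summary = df.agg(max(col("b")).as("max_b"), avg(col("c")).as("avg_c"))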

How do you select rows in PySpark DataFrame?

Selecting rows using the filter() function: the first option you have when it comes to filtering DataFrame rows is the pyspark.sql.DataFrame.filter() function, which performs filtering based on the specified conditions.
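The Scala DataFrame API offers the same filter; a minimal sketch, assuming df is an existing DataFrame with a numeric column b:

import org.apache.spark.sql.functions.col

// Keep only the rows where b is less than 10
val smallB = df.filter(col("b") < 10)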

How do you check if a column contains a particular value in PySpark?

In Spark and PySpark, the contains() function matches when a column value contains a given literal string (i.e. it matches on part of the string); it is mostly used to filter rows in a DataFrame.
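A minimal Scala sketch, assuming df has a string column named name (a hypothetical column used only for illustration):

import org.apache.spark.sql.functions.col

// Keep rows whose name column contains the substring "foo"
val matches = df.filter(col("name").contains("foo"))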


1 Answer

import org.apache.spark.sql.functions.{max, when}
import spark.implicits._ // `spark` is the SparkSession (predefined in the spark-shell); provides toDF and $""

val df = Seq(
    (1, 14, 4),
    (1, 9, 3),
    (2, 5, 6)
  ).toDF("a", "b", "c")

// when(condition, value) returns null for rows that fail the condition,
// and max ignores nulls, so only the matching rows reach each aggregate
val aggregatedDF = df.groupBy("a")
  .agg(
    max(when($"b" < 10, $"b")).as("MaxB"),
    max(when($"c" < 5, $"c")).as("MaxC")
  )

aggregatedDF.show
// |  1|   9|   4|
// |  2|   5|null|
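The same conditional aggregation can also be written in plain Spark SQL with CASE WHEN, which covers the general case of arbitrary aggregate functions and filtering conditions mentioned in the question. A sketch, assuming the temp view name filter_agg_demo is free to use:

// Register the DataFrame as a temporary view so it can be queried with SQL
df.createOrReplaceTempView("filter_agg_demo")

// CASE WHEN without an ELSE branch yields NULL, and aggregates ignore NULLs,
// so only the rows that satisfy each condition contribute to each aggregate
val sqlResult = spark.sql("""
  SELECT a,
         MAX(CASE WHEN b < 10 THEN b END) AS MaxB,
         MAX(CASE WHEN c < 5  THEN c END) AS MaxC
  FROM filter_agg_demo
  GROUP BY a
""")

sqlResult.show

Any aggregate (sum, avg, count, and so on) can be wrapped around the when(...) or CASE WHEN expression in the same way, and several such aggregates over the same column with different conditions can appear in one agg call.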
user2682459 answered Nov 02 '22