Spark DataFrame: Group By / Having with a New Indicator Column

I need to group by the "KEY" column and check whether the "TYPE_CODE" column has both the "PL" and "JL" values for that key. If so, I need to add an Indicator column set to "Y", else "N".

Example:

    // Input values (e.g. in spark-shell, where `spark` is the SparkSession)
    import spark.implicits._

    val values = Seq(
      ("66", "PL"),
      ("67", "JL"), ("67", "PL"), ("67", "PO"),
      ("68", "JL"), ("68", "PO"))

    // create a DataFrame
    val cmc = values.toDF("KEY", "TYPE_CODE")

    cmc.show(false)
    +---+---------+
    |KEY|TYPE_CODE|
    +---+---------+
    |66 |PL       |
    |67 |JL       |
    |67 |PL       |
    |67 |PO       |
    |68 |JL       |
    |68 |PO       |
    +---+---------+

Expected output:

For each KEY: if its TYPE_CODE values include both PL and JL, then Y, else N.

    +---+---------+---------+
    |KEY|TYPE_CODE|Indicator|
    +---+---------+---------+
    |66 |PL       |N        |
    |67 |JL       |Y        |
    |67 |PL       |Y        |
    |67 |PO       |Y        |
    |68 |JL       |N        |
    |68 |PO       |N        |
    +---+---------+---------+

For example, 67 has both PL and JL, so "Y"; 66 has only PL, so "N"; 68 has only JL, so "N".
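To pin down the intended semantics independently of Spark, here is a minimal plain-Scala model that works on an in-memory `Seq` of `(KEY, TYPE_CODE)` pairs. The object and method names (`IndicatorSpec`, `withIndicator`) are hypothetical; this is a reference for checking results, not a Spark solution.

```scala
// Plain-Scala reference model (no Spark): a KEY gets "Y" only when its
// TYPE_CODE values include both "PL" and "JL"; every row of that key gets the flag.
object IndicatorSpec {
  // Codes that must all be present for a "Y" (hard-coded, as in the question)
  val Required: Set[String] = Set("PL", "JL")

  def withIndicator(rows: Seq[(String, String)]): Seq[(String, String, String)] = {
    // "group by KEY": the distinct codes seen for each key
    val codesByKey: Map[String, Set[String]] =
      rows.groupBy(_._1).map { case (key, rs) => key -> rs.map(_._2).toSet }

    // "having both": Required must be a subset of the key's codes
    rows.map { case (key, code) =>
      val flag = if (Required.subsetOf(codesByKey(key))) "Y" else "N"
      (key, code, flag)
    }
  }
}
```

Because the flag is computed per key and then copied onto each row, the input order is preserved, which matches the expected output above row for row.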

RaAm asked Jan 29 '23 14:01


2 Answers

One option:

1) collect TYPE_CODE per KEY as a list;

2) check whether the list contains both strings;

3) then flatten the list back into rows with explode:

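import org.apache.spark.sql.functions._   // collect_list, array_contains, when, explode
import spark.implicits._                  // enables the $"col" syntax

(cmc.groupBy("KEY")
    .agg(collect_list("TYPE_CODE").as("TYPE_CODE"))          // 1) one list per KEY
    .withColumn("Indicator",                                  // 2) both codes present?
        when(array_contains($"TYPE_CODE", "PL") && array_contains($"TYPE_CODE", "JL"), "Y").otherwise("N"))
    .withColumn("TYPE_CODE", explode($"TYPE_CODE"))).show     // 3) back to one row per code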
+---+---------+---------+
|KEY|TYPE_CODE|Indicator|
+---+---------+---------+
| 68|       JL|        N|
| 68|       PO|        N|
| 67|       JL|        Y|
| 67|       PL|        Y|
| 67|       PO|        Y|
| 66|       PL|        N|
+---+---------+---------+
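The three steps above can be mirrored on plain Scala collections, which makes the logic easy to check without a SparkSession. This is a sketch with hypothetical names (`CollectExplodeModel`, `run`). Grouping gives no ordering guarantee across keys, in this model as in Spark's `groupBy` (which is why 68 comes before 66 in the output above), so the check compares sets of rows.

```scala
// Plain-Scala model of collect_list -> array_contains -> explode (no Spark).
object CollectExplodeModel {
  // Step 1: gather every TYPE_CODE per KEY, duplicates kept (like collect_list)
  def collect(rows: Seq[(String, String)]): Map[String, List[String]] =
    rows.groupBy(_._1).map { case (key, rs) => key -> rs.map(_._2).toList }

  // Step 2: "Y" when the list contains both codes (like the two array_contains checks)
  def flag(codes: List[String]): String =
    if (codes.contains("PL") && codes.contains("JL")) "Y" else "N"

  // Step 3: one output row per collected code, each carrying its key's flag (like explode)
  def run(rows: Seq[(String, String)]): List[(String, String, String)] =
    collect(rows).toList.flatMap { case (key, codes) =>
      val f = flag(codes)
      codes.map(code => (key, code, f))
    }
}
```

Note that step 1 materializes the whole list for a key in one place, which is the memory concern raised in the other answer.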
Psidom answered Feb 02 '23 08:02


Another option:

  1. Group by KEY and use agg to create two separate indicator columns (one for JL and one for PL), then calculate the combined indicator

  2. Join the result with the original DataFrame

Altogether:

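import org.apache.spark.sql.functions._   // sum, when
import spark.implicits._                  // enables the $"col" syntax

// per KEY: number of PL rows and number of JL rows, then the combined flag
val indicators = cmc.groupBy("KEY").agg(
  sum(when($"TYPE_CODE" === "PL", 1).otherwise(0)) as "pls",
  sum(when($"TYPE_CODE" === "JL", 1).otherwise(0)) as "jls"
).withColumn("Indicator", when($"pls" > 0 && $"jls" > 0, "Y").otherwise("N"))

// attach each key's indicator to every original row
val result = cmc.join(indicators, "KEY")
  .select("KEY", "TYPE_CODE", "Indicator")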
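The same two steps (per-key counts of each code, then a join back onto every row) can be modelled on in-memory data as a quick sanity check. Names are hypothetical, and the `Map` lookup stands in for the inner join on KEY: every key in the input also appears in the indicator map, so the lookup never fails. Since the test is `> 0`, duplicate PL or JL rows do not change the result.

```scala
// Plain-Scala model of "sum the per-code flags, then join back on KEY" (no Spark).
object SumJoinModel {
  // 1 when the row's code equals `target`, else 0 (like when(...).otherwise(0))
  private def hit(code: String, target: String): Int = if (code == target) 1 else 0

  // Per KEY: number of PL rows and JL rows, combined into "Y"/"N"
  def indicators(rows: Seq[(String, String)]): Map[String, String] =
    rows.groupBy(_._1).map { case (key, rs) =>
      val pls = rs.map(r => hit(r._2, "PL")).sum
      val jls = rs.map(r => hit(r._2, "JL")).sum
      key -> (if (pls > 0 && jls > 0) "Y" else "N")
    }

  // The join: each original row looks up its key's indicator
  def result(rows: Seq[(String, String)]): Seq[(String, String, String)] = {
    val ind = indicators(rows)
    rows.map { case (key, code) => (key, code, ind(key)) }
  }
}
```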

This might be slower than @Psidom's answer (it adds a join), but it might be safer: collect_list can be problematic if a single key has a huge number of rows, because that whole list has to be held in a single executor's memory.

EDIT:

If the input is known to be unique (i.e. JL and PL each appear at most once per key), the indicators can be created with a simple count aggregation, which is (arguably) easier to read:

val indicators = cmc
  .where($"TYPE_CODE".isin("PL", "JL"))   // keys with neither code are dropped here
  .groupBy("KEY").count()                 // counts rows, so duplicates would inflate it
  .withColumn("Indicator", when($"count" === 2, "Y").otherwise("N"))
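The count shortcut relies on the uniqueness assumption stated above, and it also filters out keys that have neither PL nor JL, so an inner join with `cmc` would drop those keys from the result entirely. The plain-Scala sketch below (hypothetical names) shows both effects: a key with PL twice and no JL reaches a count of 2, while counting distinct codes does not. In Spark the corresponding changes would presumably be `.agg(countDistinct("TYPE_CODE").as("count"))` in place of `.count()`, and a left join followed by `na.fill("N", Seq("Indicator"))` to keep the filtered-out keys; treat both as untested suggestions.

```scala
// Plain-Scala model of the count shortcut and a distinct-count variant (no Spark).
object CountModel {
  private val Wanted = Set("PL", "JL")

  // Like .where(isin("PL","JL")).groupBy("KEY").count(): counts ROWS, duplicates included
  def rowCounts(rows: Seq[(String, String)]): Map[String, Int] =
    rows.filter(r => Wanted.contains(r._2)).groupBy(_._1)
      .map { case (key, rs) => key -> rs.size }

  // Variant: count DISTINCT codes, so a repeated PL cannot reach 2 on its own
  def distinctCounts(rows: Seq[(String, String)]): Map[String, Int] =
    rows.filter(r => Wanted.contains(r._2)).groupBy(_._1)
      .map { case (key, rs) => key -> rs.map(_._2).distinct.size }

  // Keys removed by the filter have no count; treat them as "N"
  // (what a left join plus a fill with "N" would give)
  def indicator(counts: Map[String, Int], key: String): String =
    if (counts.getOrElse(key, 0) == 2) "Y" else "N"
}
```

With input `Seq("70" -> "PL", "70" -> "PL")`, `rowCounts` yields a wrong "Y" for key 70 while `distinctCounts` yields "N".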
Tzach Zohar answered Feb 02 '23 10:02