Condition on row content of a DataFrame in Spark Scala

I have the following dataframe:

+------+------+------+
|value1|value2|value3|
+------+------+------+
|     a|     2|     3|
|     b|     5|     4|
|     b|     5|     4|
|     c|     3|     4|
+------+------+------+

I would like to take the result of value2/value3 for the rows where value1 = b and add it to all rows (even the rows that do not belong to b) in a new column named "result". That means another column must be added to the DataFrame: for every row, the result of 5/4 (I chose it since it is the value for b) should be added. I know that I should use code like this:

 val dataframe_new = dataframe.withColumn("result", $"value2" / $"value3")
 dataframe_new.show()

However, how can I add the condition so that the result is applied to all rows? The output should look like this:

+---+---+---+------+
| v1| v2| v3|result|
+---+---+---+------+
|  a|  2|  3|  1.25|
|  b|  5|  4|  1.25|
|  b|  5|  4|  1.25|
|  c|  3|  4|  1.25|
+---+---+---+------+

Can you help me? Thanks in advance.

asked Jan 03 '23 by Queen

2 Answers

You just need to use when:

scala> val df = Seq(("a",2,3),("b",5,4),("b",5,4),("c",3,4)).toDF("v1","v2","v3")
df: org.apache.spark.sql.DataFrame = [v1: string, v2: int ... 1 more field]

scala> df.withColumn("result", when($"v1" === "b" , ($"v2"/$"v3"))).show
+---+---+---+------+
| v1| v2| v3|result|
+---+---+---+------+
|  a|  2|  3|  null|
|  b|  5|  4|  1.25|
|  b|  5|  4|  1.25|
|  c|  3|  4|  null|
+---+---+---+------+

You can nest multiple when expressions, as follows:

scala> df.withColumn("result", when($"v1" === "b" , ($"v2"/$"v3")).
     |    otherwise(when($"v1" === "a", $"v3"/$"v2"))).show
+---+---+---+------+
| v1| v2| v3|result|
+---+---+---+------+
|  a|  2|  3|   1.5|
|  b|  5|  4|  1.25|
|  b|  5|  4|  1.25|
|  c|  3|  4|  null|
+---+---+---+------+
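A side note (my addition, not part of the original answer): instead of nesting a when inside otherwise, when calls can also be chained directly, and otherwise can supply a default for the rows that match neither condition. A minimal sketch, assuming the same df as above:

import org.apache.spark.sql.functions._

// Chained when: the first matching condition wins; otherwise fills the rest.
df.withColumn("result",
  when($"v1" === "b", $"v2" / $"v3")
    .when($"v1" === "a", $"v3" / $"v2")
    .otherwise(lit(0.0))
).show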

EDIT: It seems you actually need something else: the rows where v1 meets the condition always have the same v2 and v3 values, which allows us to do the following:

With Spark 2+:

scala> val res = df.filter($"v1" === lit("b")).distinct.select($"v2"/$"v3").as[Double].head
res: Double = 1.25

With Spark < 2:

scala> val res = df.filter($"v1" === lit("b")).distinct.withColumn("result",$"v2"/$"v3").rdd.map(_.getAs[Double]("result")).collect()(0)
res: Double = 1.25                                                              

scala> df.withColumn("v4", lit(res)).show
+---+---+---+----+
| v1| v2| v3|  v4|
+---+---+---+----+
|  a|  2|  3|1.25|
|  b|  5|  4|1.25|
|  b|  5|  4|1.25|
|  c|  3|  4|1.25|
+---+---+---+----+
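An additional note (my addition, not in the original answer): if you would rather avoid collecting the value to the driver at all, the same result can be computed in a single pass with a global window. A sketch, assuming the same df; be aware that an empty partitionBy moves all rows into one partition, which is only acceptable for small data:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// max(when(...)) over a global window replicates the b-row ratio to every row.
val w = Window.partitionBy()
df.withColumn("result", max(when($"v1" === "b", $"v2" / $"v3")).over(w)).show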
answered Jan 13 '23 by eliasah


This answer is quite similar to eliasah's, but with a different flavor. I am writing it so that others can benefit from this approach too.

import sqlContext.implicits._

val df = Seq(
  ("a", 2, 3),
  ("b", 5, 4),
  ("b", 5, 4),
  ("c", 3, 4)
).toDF("value1", "value2", "value3")

should give

+------+------+------+
|value1|value2|value3|
+------+------+------+
|a     |2     |3     |
|b     |5     |4     |
|b     |5     |4     |
|c     |3     |4     |
+------+------+------+

And

df.withColumn("result", lit(data.filter($"value1" === "b").select($"value2"/$"value3").first.get(0)))

should generate the output

+------+------+------+------+
|value1|value2|value3|result|
+------+------+------+------+
|a     |2     |3     |1.25  |
|b     |5     |4     |1.25  |
|b     |5     |4     |1.25  |
|c     |3     |4     |1.25  |
+------+------+------+------+
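One caveat (my note, not part of the original answer): first.get(0) returns Any, so the value reaches lit untyped. If you prefer an explicitly typed Double, something like this should work; getDouble(0) is safe here because dividing two integer columns yields a double:

// Typed variant of the same driver-side lookup.
val ratio = df.filter($"value1" === "b")
  .select($"value2" / $"value3")
  .first
  .getDouble(0)

df.withColumn("result", lit(ratio)).show(false)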
answered Jan 13 '23 by Ramesh Maharjan