Question in Brief:
Put briefly, I want to iterate over all the rows sequentially, assign values to some variables (a, b, c) based on conditions specific to each row, and then write the value of one of these variables into a column of that particular row.
Detailed:
I want to update a column value in a Spark DataFrame. The update is conditional: I loop over the rows and update a column based on the values of the other columns of that row.
I tried the withColumn approach but got an error. Please suggest another approach; a fix for the withColumn approach would also be a great help.
Table:
var table1 = Seq((11, 25, 2, 0), (42, 20, 10, 0)).toDF("col_1", "col_2", "col_3", "col_4")
table1.show()
Contents:
+-----+-----+-----+-----+
|col_1|col_2|col_3|col_4|
+-----+-----+-----+-----+
| 11| 25| 2| 0|
| 42| 20| 10| 0|
+-----+-----+-----+-----+
I have tried two approaches here.
In the code below, the variables initialised at different points must stay where they are, as required by the conditions.
Code:
for (i <- table1.rdd.collect()) {
  if (i.getAs[Int]("col_1") > 0) {
    var adj_a = 0
    var adj_c = 0
    if (i.getAs[Int]("col_1") > (i.getAs[Int]("col_2") + i.getAs[Int]("col_3"))) {
      if (i.getAs[Int]("col_1") < i.getAs[Int]("col_2")) {
        adj_a = 10
        adj_c = 2
      }
      else {
        adj_a = 5
      }
    }
    else {
      adj_c = 1
    }
    adj_c = adj_c + i.getAs[Int]("col_2")
    table1.withColumn("col_4", adj_c)
    //i("col_4") = adj_c
  }
}
Error in 1st case:
table1.withColumn("col_4", adj_c)
<console>:80: error: type mismatch;
found : Int
required: org.apache.spark.sql.Column
table1.withColumn("col_4", adj_c)
^
I also tried using col(adj_c) here, but it failed with:
<console>:80: error: type mismatch;
found : Int
required: String
table1.withColumn("col_4", col(adj_c))
^
Error in 2nd case:
(i("col_4") = adj_c)
<console>:81: error: value update is not a member of org.apache.spark.sql.Row
i("col_4") = adj_c
^
I want the output table to be:
+-----+-----+-----+-----+
|col_1|col_2|col_3|col_4|
+-----+-----+-----+-----+
| 11| 25| 2| 1|
| 42| 20| 10| 5|
+-----+-----+-----+-----+
Please suggest possible solutions, and ask if anything about the question is unclear.
Please help me with this, as I am stuck on the issue. Any kind of suggestion will be very helpful.
You should use a when function instead of such complicated syntax. There is also no need for an explicit loop; Spark handles that itself. When you perform a withColumn, it is applied to each row:
table1.withColumn("col_4", when($"col_1" > $"col_2" + $"col_3", 5).otherwise(1)).show
QUICK TEST:
INPUT
table1.show
+-----+-----+-----+-----+
|col_1|col_2|col_3|col_4|
+-----+-----+-----+-----+
| 11| 25| 2| 0|
| 42| 20| 10| 0|
+-----+-----+-----+-----+
OUTPUT
table1.withColumn("col_4", when($"col_1" > $"col_2" + $"col_3", lit(5)).otherwise(1)).show
+-----+-----+-----+-----+
|col_1|col_2|col_3|col_4|
+-----+-----+-----+-----+
| 11| 25| 2| 1|
| 42| 20| 10| 5|
+-----+-----+-----+-----+
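If you need the full nested branching from the question rather than the simplified 5/1 rule above, when calls can also be nested. This is only a minimal sketch, assuming the same table1 and that spark.implicits._ and org.apache.spark.sql.functions._ are in scope; adjust the branch values to your real logic:
// Sketch: nested when mirroring the adj_c branches of the question's loop
// (2 when col_1 > col_2 + col_3 and col_1 < col_2, 0 when only the outer
// condition holds, 1 otherwise).
val adjC = when($"col_1" > $"col_2" + $"col_3",
    when($"col_1" < $"col_2", 2).otherwise(0))
  .otherwise(1)

table1.withColumn("col_4", adjC).show()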
A UDF can be used with any custom logic to calculate the column value, for example:
val calculateCol4 = (col_1: Int, col_2: Int, col_3: Int) =>
  if (col_1 > 0) {
    var adj_a = 0
    var adj_c = 0
    if (col_1 > col_2 + col_3) {
      if (col_1 < col_2) {
        adj_a = 10
        adj_c = 2
      } else {
        adj_a = 5
      }
    } else {
      adj_c = 1
    }
    println("adj_c: " + adj_c)
    adj_c = adj_c + col_2
    // added so the function returns the computed value
    adj_c
  } else {
    // added so the function returns a value when col_1 <= 0
    0
  }
val col4UDF = udf(calculateCol4)
table1.withColumn("col_4",col4UDF($"col_1", $"col_2", $"col_3"))
Using spark.sql, which is easier to read and understand:
scala> var table1 = Seq((11, 25, 2, 0), (42, 20, 10, 0)).toDF("col_1", "col_2", "col_3", "col_4")
table1: org.apache.spark.sql.DataFrame = [col_1: int, col_2: int ... 2 more fields]
scala> table1.show()
+-----+-----+-----+-----+
|col_1|col_2|col_3|col_4|
+-----+-----+-----+-----+
| 11| 25| 2| 0|
| 42| 20| 10| 0|
+-----+-----+-----+-----+
scala> table1.createOrReplaceTempView("table1")
scala> val result = spark.sql(s""" select col_1,
| col_2,
| col_3,
| CASE WHEN col_1 > (col_2 + col_3)
| THEN 5
| ELSE 1
| END as col_4
| from table1 """)
result: org.apache.spark.sql.DataFrame = [col_1: int, col_2: int ... 2 more fields]
scala> result.show(false)
+-----+-----+-----+-----+
|col_1|col_2|col_3|col_4|
+-----+-----+-----+-----+
|11 |25 |2 |1 |
|42 |20 |10 |5 |
+-----+-----+-----+-----+
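If the inner col_1 < col_2 branch from the original loop also matters, the CASE expression can be nested the same way. A sketch against the same table1 temp view, with branch values taken from the question's loop:
val result2 = spark.sql("""
  select col_1,
         col_2,
         col_3,
         CASE WHEN col_1 > (col_2 + col_3)
              THEN CASE WHEN col_1 < col_2 THEN 2 ELSE 0 END
              ELSE 1
         END as col_4
  from table1 """)

result2.show(false)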
Hope this is helpful.