I have a Spark DataFrame df with the following schema:
[id: string, label: string, tag: string]
id | label | tag
---|-------|-----
1 | h | null
1 | w | x
1 | v | null
1 | v | x
2 | h | x
3 | h | x
3 | w | x
3 | v | null
3 | v | null
4 | h | null
4 | w | x
5 | w | x
(h, w, v are labels; x can be any non-empty value.)
For each id there is at most one "h" label and at most one "w" label, but there may be multiple "v" labels. I would like to select all ids that satisfy the following conditions:
An id qualifies if it has: 1. one label "h" whose tag is null, 2. one label "w" whose tag is not null, and 3. at least one label "v".
I think I need to create three columns, one per condition, and then group by "id".
val hCheck = (label: String, tag: String) => {if (label=="h" && tag==null) 1 else 0}
val udfHCheck = udf(hCheck)
val wCheck = (label: String, tag: String) => {if (label=="w" && tag!=null) 1 else 0}
val udfWCheck = udf(wCheck)
val vCheck = (label: String) => {if (label==null) 1 else 0}
val udfVCheck = udf(vCheck)
val dfx = df.withColumn("hCheck", udfHCheck(col("label"), col("tag")))
.withColumn("wCheck", udfWCheck(col("label"), col("tag")))
.withColumn("vCheck", udfVCheck(col("label")))
.select("id","hCheck","wCheck","vCheck")
.groupBy("id")
Somehow I need to combine the three columns {"hCheck","wCheck","vCheck"} into vectors of the form [x,0,0], [0,x,0], [0,0,x], and then check whether the vectors for an id contain all three of {[1,0,0],[0,1,0],[0,0,1]}.
I have not been able to solve this problem yet, and there might be a better approach than this one. I hope someone can give me suggestions. Thanks.
To convert the three checks to vectors you can do:
val df1 = df.withColumn("hCheck", udfHCheck(col("label"), col("tag")))
.withColumn("wCheck", udfWCheck(col("label"), col("tag")))
.withColumn("vCheck", udfVCheck(col("label")))
.select($"id",array($"hCheck",$"wCheck",$"vCheck").as("vec"))
Next, groupBy returns a grouped object on which you need to perform aggregations. To collect all the vectors for each id, you should do something like:
.groupBy("id").agg(collect_list($"vec"))
Also, you do not need UDFs for the various checks; you can express them with column expressions. For example, udfHCheck can be written as:
when($"label" === lit("h") && $"tag".isNull, 1).otherwise(0)
BTW, you said you wanted at least one label "v" for each id, but vCheck only checks whether the label is null.
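As a rough sketch of the column-expression approach, the three checks can be written with when/otherwise and then reduced per id with max, which avoids building vectors at all. The local SparkSession, the app name, and the names rows, flagged, perId, hasH, hasW, hasV and selectedIds are assumptions made for illustration; the sample rows are the ones from the question. It relies on the question's statement that each id has at most one "h" and at most one "w" row.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{max, when}

// Assumption: a local session for experimentation; in spark-shell a session already exists.
val spark = SparkSession.builder().master("local[*]").appName("label-checks").getOrCreate()
import spark.implicits._

// Sample rows from the question; Option[String] models the nullable tag column.
val rows: Seq[(String, String, Option[String])] = Seq(
  ("1", "h", None), ("1", "w", Some("x")), ("1", "v", None), ("1", "v", Some("x")),
  ("2", "h", Some("x")),
  ("3", "h", Some("x")), ("3", "w", Some("x")), ("3", "v", None), ("3", "v", None),
  ("4", "h", None), ("4", "w", Some("x")),
  ("5", "w", Some("x"))
)
val df = rows.toDF("id", "label", "tag")

// One 0/1 flag per condition, built from column expressions instead of UDFs.
val flagged = df
  .withColumn("hCheck", when($"label" === "h" && $"tag".isNull, 1).otherwise(0))
  .withColumn("wCheck", when($"label" === "w" && $"tag".isNotNull, 1).otherwise(0))
  .withColumn("vCheck", when($"label" === "v", 1).otherwise(0))

// max(flag) is 1 for an id exactly when at least one of its rows satisfies that condition.
val perId = flagged.groupBy("id").agg(
  max($"hCheck").as("hasH"),
  max($"wCheck").as("hasW"),
  max($"vCheck").as("hasV")
)

// Keep only the ids for which all three conditions were met by some row.
val selectedIds = perId
  .filter($"hasH" === 1 && $"hasW" === 1 && $"hasV" === 1)
  .select("id").as[String].collect().sorted.toSeq
```

On the sample data only id 1 has an "h" row with a null tag, a "w" row with a non-null tag, and a "v" row, so selectedIds should contain just "1".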
Update: Alternative solution
Looking at this question again, I would do something like this:
val grouped = df.groupBy("id", "label").agg(count($"label").as("cnt"), first($"tag").as("tag"))
val filtered1 = grouped.filter($"label" === "v" || $"cnt" === 1)
val filtered2 = filtered1.filter($"label" === "v" || ($"label" === "h" && $"tag".isNull) || ($"label" === "w" && $"tag".isNotNull))
val ids = filtered2.groupBy("id").count.filter($"count" === 3)
The idea is to first group by BOTH id and label, so that we have information on each combination. For each combination we collect the number of rows (cnt) and the first tag (for "h" and "w" there is only one row, and for "v" the tag does not matter).
Then we apply two filtering steps: 1. we need exactly one "h", exactly one "w", and any number of "v", so the first filter keeps only those cases; 2. we make sure the tag rule is met for each label.
At this point only the combinations of id and label that match the rules remain, so for an id to be legal it must have exactly three distinct labels left. This leads to the second groupBy, which simply counts the labels that matched the rules for each id. An id is legal (i.e. matched all the rules) only if that count is exactly three.
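The steps above can be tried end to end on the sample data from the question. This is a minimal sketch, assuming a local SparkSession; the app name and the names rows and validIds are hypothetical and only there to make the snippet self-contained.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, first}

// Assumption: a local session for experimentation; in spark-shell a session already exists.
val spark = SparkSession.builder().master("local[*]").appName("label-rules").getOrCreate()
import spark.implicits._

// Sample rows from the question; Option[String] models the nullable tag column.
val rows: Seq[(String, String, Option[String])] = Seq(
  ("1", "h", None), ("1", "w", Some("x")), ("1", "v", None), ("1", "v", Some("x")),
  ("2", "h", Some("x")),
  ("3", "h", Some("x")), ("3", "w", Some("x")), ("3", "v", None), ("3", "v", None),
  ("4", "h", None), ("4", "w", Some("x")),
  ("5", "w", Some("x"))
)
val df = rows.toDF("id", "label", "tag")

// Step 1: one row per (id, label) with the row count and the first tag seen.
val grouped = df.groupBy("id", "label")
  .agg(count($"label").as("cnt"), first($"tag").as("tag"))

// Step 2: "h" and "w" must occur exactly once; "v" may occur any number of times.
val filtered1 = grouped.filter($"label" === "v" || $"cnt" === 1)

// Step 3: enforce the tag rule for "h" (null) and "w" (not null).
val filtered2 = filtered1.filter(
  $"label" === "v" ||
    ($"label" === "h" && $"tag".isNull) ||
    ($"label" === "w" && $"tag".isNotNull)
)

// Step 4: an id is legal only if all three labels survived the filters.
val ids = filtered2.groupBy("id").count().filter($"count" === 3)
val validIds = ids.select("id").as[String].collect().sorted.toSeq
```

With these rows, ids 2 and 5 lack labels, id 3 has a non-null tag on "h", and id 4 has no "v", so validIds should contain only "1".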