 

Sequences in Spark dataframe

I have a dataframe in Spark. It looks like this:

+-------+----------+-------+
|  value|     group|     ts|
+-------+----------+-------+
|      A|         X|      1|
|      B|         X|      2|
|      B|         X|      3|
|      D|         X|      4|
|      E|         X|      5|
|      A|         Y|      1|
|      C|         Y|      2|
+-------+----------+-------+

End goal: I'd like to find how many A-B-E sequences there are (a sequence is just a list of subsequent rows), with the added constraint that subsequent parts of the sequence can be at most n rows apart. Let's say for this example that n is 2.

Consider group X. In this case there is exactly one D between B and E (multiple consecutive Bs are ignored), which means B and E are 1 row apart, and thus there is a sequence A-B-E.

I have thought about using collect_list(), creating a string (like DNA) and using substring search with regex. But I was wondering if there's a more elegant distributed way, perhaps using window functions?
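A rough sketch of that idea (assuming the dataframe above is bound to df and that spark.implicits._ is in scope) could look like the following; the actual A-B-E matching would then still be a regex or substring scan over the per-group string:

import org.apache.spark.sql.functions._

// Build one ordered "DNA" string per group, e.g. X -> "ABBDE", Y -> "AC"
val dna = df
  .groupBy('group)
  .agg(sort_array(collect_list(struct('ts, 'value))).as("ordered"))  // order-safe collect, sorted by ts
  .select('group, concat_ws("", $"ordered.value").as("dna"))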

Edit:

Note that the provided dataframe is just an example. The real dataframe (and thus the groups) can be arbitrarily long.

asked Oct 21 '16 by Tim

1 Answer

Edited to answer @Tim's comment + fix patterns of the type "AABE"

Yep, using a window function helps, but I created an id to have an ordering:

// assumes spark.implicits._ is in scope (as in spark-shell) for toDF and the 'symbol column syntax
val df = List(
  (1,"A","X",1),
  (2,"B","X",2),
  (3,"B","X",3),
  (4,"D","X",4),
  (5,"E","X",5),
  (6,"A","Y",1),
  (7,"C","Y",2)
).toDF("id","value","group","ts")

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._ // split, concat, coalesce, lag, lit, locate, udf, array

val w = Window.partitionBy('group).orderBy('id)

Then lag will collect what is needed, but a function is required to generate the Column expression (note the split, which eliminates double counting of patterns like "AABE". WARNING: this rejects patterns of the type "ABAEXX"):

// For each row, look ahead at the next 2*m values in the group (lag with a negative
// offset is a lead), glue them into one string, and keep only the part before the
// next "A", so that a pattern like "AABE" is not counted twice.
def createSeq(m:Int) = split(
  concat(
    (1 to 2*m)
      .map(i => coalesce(lag('value,-i).over(w),lit("")))
  :_*),"A")(0)


val m=2
val tmp = df
  .withColumn("seq",createSeq(m))

+---+-----+-----+---+----+
| id|value|group| ts| seq|
+---+-----+-----+---+----+
|  6|    A|    Y|  1|   C|
|  7|    C|    Y|  2|    |
|  1|    A|    X|  1|BBDE|
|  2|    B|    X|  2| BDE|
|  3|    B|    X|  3|  DE|
|  4|    D|    X|  4|   E|
|  5|    E|    X|  5|    |
+---+-----+-----+---+----+

Because of the limited set of collection functions available in the Column API, it is much easier to avoid regex altogether by using a UDF:

// True if the look-ahead string contains an E at most m rows after a preceding B:
// splitting on "B" gives the stretches between Bs, and a stretch containing an E at
// index <= m means that E is at most m rows after the B right before it.
def patternInSeq(m: Int) = udf((str: String) => {
  val notFound = str
    .split("B")
    .filter(_.contains("E"))
    .filter(_.indexOf("E") <= m)
    .isEmpty
  !notFound
})
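For intuition, here is how that plays out on the seq "BBDE" computed for the first A of group X (plain Scala, with m = 2):

"BBDE".split("B")              // Array("", "", "DE")
  .filter(_.contains("E"))     // Array("DE")
  .filter(_.indexOf("E") <= 2) // Array("DE") -- non-empty, so the A-B-E pattern is found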

val res = tmp
  .filter(('value === "A") && (locate("B",'seq) > 0))       // an A with a B somewhere after it
  .filter(locate("B",'seq) <= m && (locate("E",'seq) > 1))  // first B at most m rows away, and an E further on
  .filter(patternInSeq(m)('seq))                            // that E at most m rows after its preceding B
  .groupBy('group)
  .count
res.show

+-----+-----+
|group|count|
+-----+-----+
|    X|    1|
+-----+-----+

Generalisation (out of scope)

If you want to generalise this to longer sequences of letters, the question itself has to be generalised. It could be trivial, but in this case a pattern of the type "ABAE" should be rejected (see comments). So the easiest way to generalise is to use a pair-wise rule as in the following implementation (I added a group "Z" to illustrate the behaviour of this algorithm):

val df = List(
  (1,"A","X",1),
  (2,"B","X",2),
  (3,"B","X",3),
  (4,"D","X",4),
  (5,"E","X",5),
  (6,"A","Y",1),
  (7,"C","Y",2),
  ( 8,"A","Z",1),
  ( 9,"B","Z",2),
  (10,"D","Z",3),
  (11,"B","Z",4),
  (12,"E","Z",5)
).toDF("id","value","group","ts")

First we define the logic for a pair

import org.apache.spark.sql.DataFrame
// Same look-ahead idea, but kept as an array and including the current row, so the
// array starts with the row's own value
def createSeq(m:Int) = array((0 to 2*m).map(i => coalesce(lag('value,-i).over(w),lit(""))):_*)

// True if b occurs at most m rows after the (run of consecutive) a, with the scan
// stopping as soon as a reappears. Note that Spark hands an array column to a UDF
// as a Seq, not an Array.
def filterPairUdf(m: Int, t: (String,String)) = udf((ar: Seq[String]) => {
  val (a,b) = t
  val foundAt = ar
    .dropWhile(_ != a)   // position the scan at the first occurrence of a
    .dropWhile(_ == a)   // skip the run of consecutive a's
    .takeWhile(_ != a)   // stop if a reappears
    .indexOf(b)
  foundAt != -1 && foundAt <= m
})
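To see why group "Z" ends up rejected below, here is the ("B","E") step on its look-ahead array (plain Scala, for illustration):

val ar = Seq("A", "B", "D", "B", "E")
ar.dropWhile(_ != "B")  // Seq(B, D, B, E)
  .dropWhile(_ == "B")  // Seq(D, B, E)
  .takeWhile(_ != "B")  // Seq(D) -- the second B resets the scan
  .indexOf("E")         // -1, so the (B, E) rule fails and Z is filtered out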

Then we define a function that applies this logic iteratively on the dataframe:

// Apply the pair rule to every consecutive pair of letters in the wanted sequence
def filterSeq(seq: List[String], m: Int)(df: DataFrame): DataFrame =
  seq.zip(seq.tail).foldLeft(df) { case (acc, (a, b)) =>
    acc.filter(filterPairUdf(m, (a, b))('seq))
  }

A simplification and an optimisation are obtained by first filtering on rows whose value is the first letter of the sequence:

val m = 2
val tmp = df
  .filter('value === "A") // reduce problem
  .withColumn("seq",createSeq(m))

scala> tmp.show()
+---+-----+-----+---+---------------+
| id|value|group| ts|            seq|
+---+-----+-----+---+---------------+
|  6|    A|    Y|  1|   [A, C, , , ]|
|  8|    A|    Z|  1|[A, B, D, B, E]|
|  1|    A|    X|  1|[A, B, B, D, E]|
+---+-----+-----+---+---------------+

val res = tmp.transform(filterSeq(List("A","B","E"),m))

scala> res.show()
+---+-----+-----+---+---------------+
| id|value|group| ts|            seq|
+---+-----+-----+---+---------------+
|  1|    A|    X|  1|[A, B, B, D, E]|
+---+-----+-----+---+---------------+

(transform is simple syntactic sugar for applying a DataFrame => DataFrame function in a method chain)
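For instance, the call above is strictly equivalent to applying the function directly:

// transform(t) is just t(this), so this is the same computation without the sugar
val resDirect = filterSeq(List("A","B","E"), m)(tmp)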

res
  .groupBy('group)
  .count
  .show

+-----+-----+
|group|count|
+-----+-----+
|    X|    1|
+-----+-----+

As I said, there are different ways to generalise the "resetting rules" when scanning a sequence, but this example hopefully helps with the implementation of more complex ones.
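One possible variant, as a minimal sketch: a pair rule that also resets whenever a chosen letter reappears (for instance the first letter of the wanted sequence), so that "ABAE"-like patterns are rejected for A-B-E as well:

// Hypothetical variant of filterPairUdf (a sketch, not part of the solution above):
// the scan also aborts when `resetOn` reappears between a and b
def filterPairResetUdf(m: Int, t: (String,String), resetOn: String) = udf((ar: Seq[String]) => {
  val (a,b) = t
  val foundAt = ar
    .dropWhile(_ != a)
    .dropWhile(_ == a)
    .takeWhile(x => x != a && x != resetOn)  // reset on a repeat of a OR of resetOn
    .indexOf(b)
  foundAt != -1 && foundAt <= m
})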

answered by Wilmerton