
Scala: Functional aggregation of Seq[T] elements => Seq[Seq[T]] (preserving order)

I'd like to aggregate compatible elements in a sequence, i.e. transform a Seq[T] into a Seq[Seq[T]] where elements in each subsequence are compatible with each other while the original seq order is preserved, e.g. from

case class X(i: Int, n: Int) {
  def canJoin(that: X): Boolean = this.n == that.n
  override val toString = i + "." + n
}
val xs = Seq(X(1, 1), X(2, 3), X(3, 3), X(4, 3), X(5, 1), X(6, 2), X(7, 2), X(8, 1))
/* xs = List(1.1, 2.3, 3.3, 4.3, 5.1, 6.2, 7.2, 8.1) */

I want to get

val js = join(xs)
/* js = List(List(1.1), List(2.3, 3.3, 4.3), List(5.1), List(6.2, 7.2), List(8.1)) */

I've tried to do this in a functional way, but I got stuck halfway:

Using a while loop

def split(seq: Seq[X]): (Seq[X], Seq[X]) = seq.span(_ canJoin seq.head)
def join(seq: Seq[X]): Seq[Seq[X]] = {
  var pp = Seq[Seq[X]]()
  var s = seq
  while (!s.isEmpty) {
    val (p, r) = split(s)
    pp :+= p
    s = r
  }
  pp
}

I'm satisfied with split, but join seems a bit too long.

In my opinion, that's a standard task. That leads me to the question:

  1. Are there functions in the collections library that make it possible to reduce the code size?
  2. Or is there perhaps a different approach to the task, in particular one other than the approach in "Rewriting a sequence by partitioning and collapsing"?

Replacing while loop with tail recursion

def join(xs: Seq[X]): Seq[Seq[X]] = {
  @annotation.tailrec
  def jointr(pp: Seq[Seq[X]], rem: Seq[X]): Seq[Seq[X]] =
    if (rem.isEmpty) pp // nothing left; an empty input yields an empty result
    else {
      val (p, r) = split(rem) // p: leading run of compatible elements, r: the rest
      jointr(pp :+ p, r)
    }
  jointr(Seq(), xs)
}
binuWADa asked Nov 10 '11 20:11


2 Answers

def join(seq: Seq[X]): Seq[Seq[X]] =
  if (seq.isEmpty) Seq()
  else {
    val (p, r) = split(seq)
    Seq(p) ++ join(r)
  }
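
On the first question (library functions that shorten the code): one possibility, not part of the answer above and assuming Scala 2.13 or later, is Iterator.unfold, which keeps applying a span step until the input is exhausted. The name joinUnfold is made up for this sketch, and X is repeated from the question (with an interpolated toString) so the snippet stands alone.

```scala
case class X(i: Int, n: Int) {
  def canJoin(that: X): Boolean = this.n == that.n
  override def toString: String = s"$i.$n"
}

// Hypothetical name; requires Scala 2.13+ for Iterator.unfold.
// State = the remaining elements; each step emits the leading run of
// elements compatible with the head and continues with the rest.
def joinUnfold(seq: Seq[X]): Seq[Seq[X]] =
  Iterator.unfold[Seq[X], Seq[X]](seq) { rem =>
    if (rem.isEmpty) None
    else Some(rem.span(_ canJoin rem.head))
  }.toList

val xs = Seq(X(1, 1), X(2, 3), X(3, 3), X(4, 3), X(5, 1), X(6, 2), X(7, 2), X(8, 1))
println(joinUnfold(xs))
// prints List(List(1.1), List(2.3, 3.3, 4.3), List(5.1), List(6.2, 7.2), List(8.1))
```

Unlike the recursive join, this version does not recurse once per group, so the number of groups does not affect stack depth.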
Peter Schmitz answered Sep 25 '22 17:09


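Here is a foldLeft version:

def join(seq: Seq[X]) = seq.reverse.foldLeft(Nil: List[List[X]]) {
    // x is compatible with the first element of the group built so far: prepend it
    case ((top :: group) :: rest, x) if x canJoin top =>
        (x :: top :: group) :: rest
    // otherwise x starts a new group
    case (list, x) => (x :: Nil) :: list
}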

and a foldRight version (you don't need to reverse the list in this case):

def join(seq: Seq[X]) = seq.foldRight(Nil: List[List[X]]) {
    // x is compatible with the first element of the group to its right: prepend it
    case (x, (top :: group) :: rest) if x canJoin top =>
        (x :: top :: group) :: rest
    // otherwise x starts a new group
    case (x, list) => (x :: Nil) :: list
}
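
The same fold does not depend on X. As a rough sketch, it can take the compatibility check as a parameter; groupRuns and its canJoin argument are names invented for this example, not library functions.

```scala
// Hypothetical generic helper: groups consecutive elements, preserving order.
// An element joins the group to its right when canJoin(element, firstOfThatGroup) holds.
def groupRuns[T](seq: Seq[T])(canJoin: (T, T) => Boolean): List[List[T]] =
  seq.foldRight(List.empty[List[T]]) {
    case (x, (top :: group) :: rest) if canJoin(x, top) =>
      (x :: top :: group) :: rest // x extends the group that follows it
    case (x, groups) =>
      List(x) :: groups           // x starts a new group
  }

val runs = groupRuns(Seq(1, 3, 3, 3, 1, 2, 2, 1))(_ == _)
println(runs)
// prints List(List(1), List(3, 3, 3), List(1), List(2, 2), List(1))
```

Note that each element is compared with its immediate right-hand neighbour, whereas split compares against the head of the run. The two agree when the check is an equivalence such as equality of n, but may differ for a non-transitive check.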
tenshi answered Sep 24 '22 17:09