
Functional way to take elements from a list until a limit is reached in Scala

The aim of the method is to take elements from a list for as long as their running sum stays within a limit.

e.g.

I've come up with two different implementations:

def take(l: List[Int], limit: Int): List[Int] = {
  var sum = 0
  l.takeWhile { e =>
    sum += e
    sum <= limit
  }
}

It is straightforward, but it relies on mutable state.

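def take(l: List[Int], limit: Int): List[Int] = {
  // summed(i) is the sum of the first i elements, so summed(0) == 0
  val summed = l.toStream.scanLeft(0) { case (sum, e) => sum + e }
  // indexWhere returns -1 when the limit is never exceeded: keep the whole list
  val idx = summed.indexWhere(_ > limit)
  if (idx < 0) l else l.take(idx - 1)
}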

It seems cleaner, but it's more verbose and perhaps less memory-efficient because it needs a stream.

Is there a better way?

Yann Moisan asked Apr 02 '17 08:04


2 Answers

You could also do that in a single pass with a fold:

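  def take(l: List[Int], limit: Int): List[Int] =
    // foldLeft, not fold: the accumulator type differs from the element type
    l.foldLeft((List.empty[Int], 0)) { case ((res, acc), next) =>
      if (acc + next > limit)
        (res, limit) // pin acc at limit so later positive elements are rejected too
      else
        (next :: res, next + acc)
    }._1.reverse // res was built back to front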

Because standard lists are strict, and so is a fold, this always traverses the entire list. One alternative is to use cats' iteratorFoldM instead, for an implementation that short-circuits once the limit is reached.
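The short-circuiting idea can also be sketched with the standard library alone. This is not the cats approach mentioned above, just a minimal illustration that assumes iterators are acceptable; the name takeUntilLimit is hypothetical.

```scala
// Hypothetical helper, not from the original answer.
def takeUntilLimit(l: List[Int], limit: Int): List[Int] = {
  // Running sum after each element: for List(1, 2, 3) this yields 1, 3, 6.
  val runningSums = l.iterator.scanLeft(0)(_ + _).drop(1)
  // Iterators are lazy, so takeWhile stops pulling elements
  // at the first running sum that exceeds the limit.
  l.iterator
    .zip(runningSums)
    .takeWhile { case (_, sum) => sum <= limit }
    .map { case (elem, _) => elem }
    .toList
}
```

Because nothing is consumed past the first failing element, the rest of the list is never visited, and no intermediate stream is kept in memory.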

You could also write the short-circuiting fold directly using tail recursion, something along these lines:

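def take(l: List[Int], limit: Int): List[Int] = {
  @annotation.tailrec
  def take0(list: List[Int], accList: List[Int], accSum: Int): List[Int] =
    list match {
      // keep going while the running sum stays within the limit (inclusive)
      case h :: t if accSum + h <= limit =>
        take0(t, h :: accList, h + accSum)
      // list exhausted or limit exceeded: stop here
      case _ => accList
    }
  // accList was built in reverse order
  take0(l, Nil, 0).reverse
}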

Note that this second solution might be faster, but it is also less elegant: it takes additional effort to prove that the implementation terminates, which is obvious when using a fold.

OlivierBlanvillain answered Nov 14 '22 22:11


The first way is perfectly fine: the mutation is confined to a local variable, so the result of your function is still immutable and callers never see the mutable state.

On a side note, this is actually how many functions of the Scala collection library are implemented: they create a mutable builder for efficiency and return an immutable collection from it.
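As a rough illustration of that builder pattern applied to this problem (a sketch, not code taken from the collection library; the name takeWithBuilder is hypothetical):

```scala
// Hypothetical helper: local mutation inside, immutable List outside.
def takeWithBuilder(l: List[Int], limit: Int): List[Int] = {
  val builder = List.newBuilder[Int] // mutable builder, never escapes this method
  var sum = 0
  var withinLimit = true
  val it = l.iterator
  // Stop as soon as adding the next element would exceed the limit.
  while (withinLimit && it.hasNext) {
    val e = it.next()
    if (sum + e <= limit) {
      sum += e
      builder += e
    } else {
      withinLimit = false
    }
  }
  builder.result() // immutable List handed back to the caller
}
```

All mutable state lives inside the method body, so from the caller's point of view the function behaves like a pure one.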

Frederic A. answered Nov 14 '22 21:11