A direct cut and paste of the following algorithm:
def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys
      case (_, Nil) => xs
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) x :: merge(xs1, ys)
        else y :: merge(xs, ys1)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    merge(msort(less)(ys), msort(less)(zs))
  }
}
causes a StackOverflowError on lists of 5000 elements.
Is there any way to optimize this so that it doesn't happen?
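This happens because merge isn't tail-recursive: in x :: merge(xs1, ys) the cons is applied only after the recursive call returns, so every element of the merged list costs one stack frame. You can fix this either by using a non-strict collection or by making merge tail-recursive.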
The latter solution goes like this:
import scala.annotation.tailrec

def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
  // acc collects the merged elements in reverse order
  @tailrec
  def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys.reverse ::: acc
      case (_, Nil) => xs.reverse ::: acc
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) merge(xs1, ys, x :: acc)
        else merge(xs, ys1, y :: acc)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    // undo the reversal built up in acc
    merge(msort(less)(ys), msort(less)(zs), Nil).reverse
  }
}
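A slightly tighter variant of the same accumulator idea, sketched under the assumption of Scala 2.13 or later. Because acc holds the merged prefix in reverse, the leftover list can be attached with List's reverse_::: (acc reverse_::: rest is acc.reverse ::: rest in one pass), which removes the final .reverse. The @tailrec annotation makes the compiler reject merge if the recursive call ever leaves tail position. The object name TailRecMergeSort is only for illustration.

```scala
import scala.annotation.tailrec

object TailRecMergeSort {
  def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
    // acc holds the already-merged prefix in reverse order
    @tailrec
    def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] =
      (xs, ys) match {
        // acc reverse_::: rest == acc.reverse ::: rest
        case (Nil, _) => acc reverse_::: ys
        case (_, Nil) => acc reverse_::: xs
        case (x :: xs1, y :: ys1) =>
          if (less(x, y)) merge(xs1, ys, x :: acc)
          else merge(xs, ys1, y :: acc)
      }

    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      // msort itself only recurses about log2(xs.length) levels deep
      merge(msort(less)(ys), msort(less)(zs), Nil)
    }
  }
}
```

Usage: TailRecMergeSort.msort[Int](_ < _)(bigList) sorts lists far longer than 5000 elements, since stack depth no longer grows with list length.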
Using non-strictness involves either passing parameters by name or using non-strict collections such as Stream. The following code uses Stream just to prevent the stack overflow, and List elsewhere:
def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
  def merge(left: List[T], right: List[T]): Stream[T] = (left, right) match {
    case (x :: xs, y :: _) if less(x, y) => Stream.cons(x, merge(xs, right))
    case (_ :: _, y :: ys) => Stream.cons(y, merge(left, ys))
    case _ => if (left.isEmpty) right.toStream else left.toStream
  }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    merge(msort(less)(ys), msort(less)(zs)).toList
  }
}
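Stream has been deprecated since Scala 2.13 in favour of LazyList. Below is a sketch of the same non-strict approach, assuming Scala 2.13 or later. The #:: operator takes its tail by name, so each recursive merge call only runs when toList asks for the next element, and the stack does not grow with the list length. The object name LazyMergeSort is hypothetical.

```scala
object LazyMergeSort {
  def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
    // x #:: merge(...) does not call merge until the tail is demanded
    def merge(left: List[T], right: List[T]): LazyList[T] = (left, right) match {
      case (x :: xs, y :: _) if less(x, y) => x #:: merge(xs, right)
      case (_ :: _, y :: ys)               => y #:: merge(left, ys)
      case _ => if (left.isEmpty) LazyList.from(right) else LazyList.from(left)
    }

    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      // toList forces the merged elements one at a time, iteratively
      merge(msort(less)(ys), msort(less)(zs)).toList
    }
  }
}
```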
Just playing around with Scala's TailCalls (trampolining support), which I suspect wasn't around when this question was originally posed. Here's a recursive, immutable version of the merge from Rex's answer.
import scala.util.control.TailCalls._

def merge[T <% Ordered[T]](x: List[T], y: List[T]): List[T] = {
  def build(s: List[T], a: List[T], b: List[T]): TailRec[List[T]] = {
    if (a.isEmpty) {
      done(b.reverse ::: s)
    } else if (b.isEmpty) {
      done(a.reverse ::: s)
    } else if (a.head < b.head) {
      tailcall(build(a.head :: s, a.tail, b))
    } else {
      tailcall(build(b.head :: s, a, b.tail))
    }
  }
  build(List(), x, y).result.reverse
}
Runs just as fast as the mutable version on big List[Long]s on Scala 2.9.1 on 64-bit OpenJDK (Debian/Squeeze amd64 on an i7).
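The merge above is only one half of a sort. A sketch of wiring it into a complete msort follows. It assumes an implicit Ordering[T] in place of the view bound, since <% is deprecated in later Scala 2 versions and not supported in Scala 3. The names TrampolinedMergeSort and msort are illustrative. Because build only calls itself, plain @tailrec would also work here. TailCalls becomes necessary when the recursion bounces between several functions, which @tailrec cannot optimize.

```scala
import scala.util.control.TailCalls._

object TrampolinedMergeSort {
  def merge[T](x: List[T], y: List[T])(implicit ord: Ordering[T]): List[T] = {
    // s is the merged prefix, kept in reverse order
    def build(s: List[T], a: List[T], b: List[T]): TailRec[List[T]] =
      if (a.isEmpty) done(b.reverse ::: s)
      else if (b.isEmpty) done(a.reverse ::: s)
      else if (ord.lt(a.head, b.head)) tailcall(build(a.head :: s, a.tail, b))
      else tailcall(build(b.head :: s, a, b.tail))

    // result runs the trampoline in a loop, so the stack stays flat
    build(Nil, x, y).result.reverse
  }

  def msort[T](xs: List[T])(implicit ord: Ordering[T]): List[T] = {
    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      merge(msort(ys), msort(zs))
    }
  }
}
```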