How to change the functional insert-sort code to be tail recursive

Recently I implemented the insert_sort algorithm in a functional programming style, and it became more concise and declarative. The question is how to make it tail recursive: the code throws an exception once the size of the list grows to 10000.

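def InsertSort(xs: List[Int]): List[Int] = xs match {
    case Nil => Nil
    case x::rest =>
       // not tail recursive: y :: insert(x, ys) still has to prepend y after the call returns
       def insert(x: Int, sorted_xs: List[Int]): List[Int] = sorted_xs match {
           case Nil => List(x)
           case y::ys => if (x <= y) x::y::ys else y::insert(x, ys)
       }
       // not tail recursive either: insert runs only after InsertSort(rest) returns
       insert(x, InsertSort(rest))
 }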
asked Dec 06 '14 16:12 by eita


2 Answers

Just introduce accumulators:

 import scala.annotation.tailrec

 // acc holds the elements sorted so far
 @tailrec def InsertSort(xs: List[Int], acc: List[Int] = Nil): List[Int] =
  if (xs.nonEmpty) {
    val x :: rest = xs
    // here acc holds the already visited elements of sorted_xs (all smaller than x)
    @tailrec
    def insert(x: Int, sorted_xs: List[Int], acc: List[Int] = Nil): List[Int] =
      if (sorted_xs.nonEmpty) {
        val y :: ys = sorted_xs
        if (x <= y) acc ::: x :: y :: ys else insert(x, ys, acc :+ y)
      } else acc ::: List(x)
    InsertSort(rest, insert(x, acc))
  } else acc

Both ::: and :+ take O(n) time on the default List implementation, so it's better to use a more appropriate collection (like ListBuffer). You can also rewrite it with foldLeft instead of explicit recursion.
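A minimal sketch of the ListBuffer idea, assuming Scala 2.13 or Scala 3 (where ListBuffer.prependToList is available). The object name ListBufferInsertSort and the inner helper loop are hypothetical, not part of the answer. The elements smaller than x are appended to a ListBuffer (constant-time append instead of :+), and the buffer is then prepended to x :: rest, so the tail of the sorted list is shared rather than copied. The outer loop uses foldLeft as suggested above.

```scala
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

// Hypothetical name, used only for illustration
object ListBufferInsertSort {
  // Inserts x into the already sorted list `sorted`.
  def insert(sorted: List[Int], x: Int): List[Int] = {
    // Collects the prefix of elements smaller than x; += is O(1) on ListBuffer
    val prefix = ListBuffer.empty[Int]

    @tailrec
    def loop(rest: List[Int]): List[Int] = rest match {
      case y :: ys if y < x =>
        prefix += y
        loop(ys)
      case _ =>
        // Links the buffered prefix in front of x :: rest without copying rest
        prefix.prependToList(x :: rest)
    }

    loop(sorted)
  }

  // foldLeft is iterative, so no recursion depth grows with the input size
  def insertSort(xs: List[Int]): List[Int] =
    xs.foldLeft(List.empty[Int])(insert)
}
```

For example, ListBufferInsertSort.insertSort(List(1, 5, 3, 6, 9, 6, 7)) returns List(1, 3, 5, 6, 6, 7, 9). Insertion sort stays O(n^2) overall; the buffer only removes the extra cost of :+ inside each insertion.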

Faster option (with foldLeft, without :+):

 import scala.annotation.tailrec

 // acc holds the visited elements smaller than x, in reverse order
 @tailrec
 def insert(sorted_xs: List[Int], x: Int, acc: List[Int] = Nil): List[Int] =
   if (sorted_xs.nonEmpty) {
     val y::ys = sorted_xs
     if (x <= y) acc.reverse ::: x :: y :: ys else insert(ys, x, y :: acc)
   } else (x :: acc).reverse

 scala> List(1,5,3,6,9,6,7).foldLeft(List[Int]())(insert(_, _))
 res22: List[Int] = List(1, 3, 5, 6, 6, 7, 9)

And finally, with span (as in @roterl's answer, but span is a little faster: it traverses the collection only until it finds an element >= x, while partition always walks the whole list):

 def insert(sorted_xs: List[Int], x: Int) = if (sorted_xs.nonEmpty) { 
    val (smaller, larger) = sorted_xs.span(_ < x)
    smaller ::: x :: larger
 } else x :: Nil

 scala> List(1,5,3,6,9,6,7).foldLeft(List[Int]())(insert)
 res25: List[Int] = List(1, 3, 5, 6, 6, 7, 9)
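To illustrate why span can replace partition here, the following sketch counts how often the predicate _ < x is evaluated by each method. The object SpanVsPartition and its helpers counted, viaSpan and viaPartition are hypothetical names for illustration. On an already sorted list both methods produce the same split, but span stops at the first element >= x. On an unsorted list span gives a different split, so the span version relies on sorted_xs being sorted.

```scala
// Hypothetical helper object, used only for illustration
object SpanVsPartition {
  // Runs `split` with the predicate (_ < x) and returns the split together
  // with the number of times the predicate was evaluated.
  private def counted(
      xs: List[Int],
      x: Int,
      split: (List[Int], Int => Boolean) => (List[Int], List[Int])
  ): ((List[Int], List[Int]), Int) = {
    var calls = 0
    val lessThanX = (y: Int) => { calls += 1; y < x }
    val result = split(xs, lessThanX)
    (result, calls)
  }

  // span stops at the first element that fails the predicate
  def viaSpan(xs: List[Int], x: Int): ((List[Int], List[Int]), Int) =
    counted(xs, x, (l, p) => l.span(p))

  // partition examines every element
  def viaPartition(xs: List[Int], x: Int): ((List[Int], List[Int]), Int) =
    counted(xs, x, (l, p) => l.partition(p))
}
```

For the sorted list List(1, 3, 5, 6, 9) and x = 6, both methods split it into List(1, 3, 5) and List(6, 9). For the unsorted List(6, 1, 9, 3, 5), span returns an empty prefix because the very first element already fails the predicate.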
answered Sep 23 '22 20:09 by dk14


To make it tail recursive, pass the sorted list as a parameter instead of building it in the return value:

import scala.annotation.tailrec

def InsertSort(xs: List[Int]): List[Int] = {
  // sorted_xs is the accumulator: the sorted result is passed along as an argument
  @tailrec
  def doSort(unsortXs: List[Int], sorted_xs: List[Int]): List[Int] = {
    unsortXs match {
      case Nil => sorted_xs
      case x::rest =>
        val (smaller, larger) = sorted_xs.partition(_ < x)
        doSort(rest, smaller ::: x :: larger)
    }
  }
  doSort(xs, List())
}
answered Sep 21 '22 20:09 by roterl