
lazy sorts of iterators in Scala?

I've read that in Haskell, sorting a list is lazy: it evaluates only as much of the quicksort as is needed to produce the values actually consumed from the result. For example, once it has finished partitioning the left-hand side of the first pivot and can return one value, it can hand that value to a call to "next" on the iterator, and it does not continue pivoting unless "next" is called again.

For example, in Haskell, head (qsort list) is O(n). It just finds the minimum value in the list and doesn't sort the rest unless the rest of the result of qsort list is accessed.

Is there a way to do this in Scala? I want to use sortWith on a collection but only sort as much as necessary, so that I can call mySeq.sortWith(_ < _).take(3) without it needing to complete the sort operation.

I'd like to know if other sort functions (like sortBy) can be used in a lazy way, and how to ensure laziness, and how to find any other documentation about when sorts in Scala are or are not lazily evaluated.

UPDATE/EDIT: I'm ideally looking for ways to do this with standard sorting functions like sortWith. I'd rather not have to implement my own version of quicksort just to get lazy evaluation. Shouldn't this be built into the standard library, at least for collections like Stream that support laziness??

nairbv asked Sep 28 '12 11:09

2 Answers

I've used Scala's priority queue implementation to solve this kind of partial sorting problem:

import scala.collection.mutable.PriorityQueue

val q = PriorityQueue(1289, 12, 123, 894, 1)(Ordering.Int.reverse)

Now we can call dequeue:

scala> q.dequeue
res0: Int = 1

scala> q.dequeue
res1: Int = 12

scala> q.dequeue
res2: Int = 123

It costs O(n) to build the queue and O(k log n) to take the first k elements.

Unfortunately PriorityQueue doesn't iterate in priority order, but it's not too hard to write an iterator that calls dequeue.
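A minimal sketch of such an iterator follows. The `LazySorted` object is a hypothetical helper name (it is not part of the standard library); it copies the input into a min-first queue and dequeues one element per call to `next()`, so only the elements actually consumed are taken out of the heap.

```scala
import scala.collection.mutable.PriorityQueue

// Hypothetical helper: exposes a collection's elements in ascending
// order, dequeuing from a priority queue only when asked.
object LazySorted {
  def apply[A](xs: Iterable[A])(implicit ord: Ordering[A]): Iterator[A] = {
    // PriorityQueue dequeues the largest element first, so the ordering
    // is reversed to get the smallest element first.
    val q = PriorityQueue.empty[A](ord.reverse)
    q ++= xs
    new Iterator[A] {
      def hasNext: Boolean = q.nonEmpty
      def next(): A = q.dequeue()
    }
  }
}
```

With this, LazySorted(Seq(1289, 12, 123, 894, 1)).take(3).toList gives List(1, 12, 123). Note that the queue is filled eagerly; only the dequeuing is on demand.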

Travis Brown answered Oct 22 '22 08:10


As an example, I created an implementation of lazy quicksort that builds a lazy tree structure instead of producing a result list. The structure can be asked for any i-th element in O(n) time, or for a slice of k elements. Asking for the same element again (or a nearby element) takes only O(log n), because the tree structure built in the previous step is reused. Traversing all elements takes O(n log n) time. (All of this assumes that we've chosen reasonable pivots.)

The key is that subtrees are not built right away, they're delayed in a lazy computation. So when asking only for a single element, the root node is computed in O(n), then one of its sub-nodes in O(n/2) etc. until the required element is found, taking O(n + n/2 + n/4 ...) = O(n). When the tree is fully evaluated, picking any element takes O(log n) as with any balanced tree.
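To illustrate the delaying idea in isolation, here is a small sketch, separate from the implementation in this answer. `Delayed` and `DelayedDemo` are hypothetical names used only for this illustration: a by-name argument is stored unevaluated and cached in a `lazy val`, so the computation runs on first access and at most once.

```scala
object DelayedDemo {
  // Hypothetical wrapper: `init` is passed by name, so it is not
  // evaluated at construction; the lazy val caches it on first access.
  final class Delayed[A](init: => A) {
    lazy val value: A = init
  }

  // Counts how many times the delayed computation actually runs.
  var evaluations = 0

  // Stands in for a subtree: nothing is sorted until `value` is read.
  val subtree = new Delayed({ evaluations += 1; List(3, 1, 2).sorted })
}
```

Reading DelayedDemo.subtree.value twice leaves evaluations at 1: the first read runs the computation, the second reuses the cached result.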

Note that the implementation of build is quite inefficient. I wanted it to be simple and as easy to understand as possible. The important thing is that it has the proper asymptotic bounds.

import collection.immutable.Traversable

object LazyQSort {
  /**
   * Represents a value that is evaluated at most once.
   */
  final protected class Thunk[+A](init: => A) extends Function0[A] {
    override lazy val apply: A = init;
  }

  implicit protected def toThunk[A](v: => A): Thunk[A] = new Thunk(v);
  implicit protected def fromThunk[A](t: Thunk[A]): A = t.apply;

  // -----------------------------------------------------------------

  /**
   * A lazy binary tree that keeps a list of sorted elements.
   * Subtrees are created lazily using `Thunk`s, so only
   * the necessary part of the whole tree is created for
   * each operation.
   *
   * Most notably, accessing any i-th element using `apply`
   * takes O(n) time and traversing all the elements
   * takes O(n * log n) time.
   */
  sealed abstract class Tree[+A]
    extends Function1[Int,A] with Traversable[A]
  {
    override def apply(i: Int) = findNth(this, i);

    override def head: A = apply(0);
    override def last: A = apply(size - 1);
    def max: A = last;
    def min: A = head;
    override def slice(from: Int, until: Int): Traversable[A] =
      LazyQSort.slice(this, from, until);
    // We could implement more Traversable's methods here ...
  }
  final protected case class Node[+A](
      pivot: A, leftSize: Int, override val size: Int,
      left: Thunk[Tree[A]], right: Thunk[Tree[A]]
    ) extends Tree[A]
  {
    override def foreach[U](f: A => U): Unit = {
      left.foreach(f);
      f(pivot);
      right.foreach(f);
    }
    override def isEmpty: Boolean = false;
  }
  final protected case object Leaf extends Tree[Nothing] {
    override def foreach[U](f: Nothing => U): Unit = {}
    override def size: Int = 0;
    override def isEmpty: Boolean = true;
  }

  // -----------------------------------------------------------------

  /**
   * Finds the n-th element (0-based) of the tree.
   */
  @annotation.tailrec
  protected def findNth[A](tree: Tree[A], n: Int): A =
    tree match {
      case Leaf => throw new ArrayIndexOutOfBoundsException(n);
      case Node(pivot, lsize, _, l, r)
                => if (n == lsize) pivot
                   else if (n < lsize) findNth(l, n)
                   else findNth(r, n - lsize - 1);
    }

  /**
   * Cuts a given subinterval from the data.
   */
  def slice[A](tree: Tree[A], from: Int, until: Int): Traversable[A] =
    tree match {
      case Leaf => Leaf
      case Node(pivot, lsize, size, l, r) => {
        lazy val sl = slice(l, from, until);
        lazy val sr = slice(r, from - lsize - 1, until - lsize - 1);
        // Nothing in range: stop here without forcing either subtree
        if ((until <= 0) || (from >= size)) Leaf // empty
        else if (until <= lsize) sl
        else if (from > lsize) sr
        else sl ++ Seq(pivot) ++ sr
      }
  }

  // -----------------------------------------------------------------

  /**
   * Builds a tree from a given sequence of data.
   */
  def build[A](data: Seq[A])(implicit ord: Ordering[A]): Tree[A] =
    if (data.isEmpty) Leaf
    else {
      // selecting a pivot is traditionally a complex matter,
      // for simplicity we take the middle element here
      val pivotIdx = data.size / 2;
      val pivot = data(pivotIdx);
      // this is far from perfect, but still linear
      val (l, r) = data.patch(pivotIdx, Seq.empty, 1).partition(ord.lteq(_, pivot));
      Node(pivot, l.size, data.size, { build(l) }, { build(r) });
    }
}

// ###################################################################

/**
 * Tests some operations and prints results to stdout.
 */
object LazyQSortTest extends App {
  import util.Random
  import LazyQSort._

  def trace[A](name: String, comp: => A): A = {
    val start = System.currentTimeMillis();
    val r: A = comp;
    val end = System.currentTimeMillis();
    println("-- " + name + " took " + (end - start) + "ms");
    return r;
  }

  {
    val n = 1000000;
    val rnd = Random.shuffle(0 until n);
    val tree = build(rnd);
    trace("1st element", println(tree.head));
    // Second element is much faster since most of the required
    // structure is already built
    trace("2nd element", println(tree(1)));
    trace("Last element", println(tree.last));
    trace("Median element", println(tree(n / 2)));
    trace("Median + 1 element", println(tree(n / 2 + 1)));
    trace("Some slice", for(i <- tree.slice(n/2, n/2+30)) println(i));
    trace("Traversing all elements", for(i <- tree) i);
    trace("Traversing all elements again", for(i <- tree) i);
  }
}

The output will be something like

0
-- 1st element took 268ms
1
-- 2nd element took 0ms
999999
-- Last element took 39ms
500000
-- Median element took 122ms
500001
-- Median + 1 element took 0ms
500000
  ...
500029
-- Some slice took 6ms
-- Traversing all elements took 7904ms
-- Traversing all elements again took 191ms

Petr answered Oct 22 '22 08:10