 

scala median implementation

What's a fast implementation of median in scala?

This is what I found on rosetta code:

    def median(s: Seq[Double]) = {
      val (lower, upper) = s.sortWith(_ < _).splitAt(s.size / 2)
      if (s.size % 2 == 0) (lower.last + upper.head) / 2.0 else upper.head
    }

I don't like it because it does a sort. I know there are ways to compute the median in linear time.

EDIT:

I would like to have a set of median functions that I can use in various scenarios:

  1. fast, in place median computation that can be done in linear time
  2. median that works on a stream that you can traverse multiple times, but you can only keep O(log n) values in memory like this
  3. median that works on a stream, where you can hold at most O(log n) values in memory, and you can traverse the stream at most once (is this even possible?)

Please only post code that compiles and correctly computes the median. For simplicity, you may assume that all inputs contain an odd number of values.

asked Jan 11 '11 20:01 by dsg


1 Answer

Immutable Algorithm

The first algorithm indicated by Taylor Leese is quadratic in the worst case, but linear on average. That, however, depends on the pivot selection. So here is a version with pluggable pivot selection, together with two pivot choices: a random pivot and the median of medians pivot (which guarantees linear time).

    import scala.annotation.tailrec

    // Returns the k-th smallest element of arr (k is 0-based)
    @tailrec
    def findKMedian(arr: Array[Double], k: Int)(implicit choosePivot: Array[Double] => Double): Double = {
        val a = choosePivot(arr)
        // s: elements smaller than the pivot; b: elements greater than or equal to it
        val (s, b) = arr partition (a > _)
        if (s.size == k) a
        // The following test is used to avoid infinite repetition: if the pivot is the
        // minimum, b is the whole array, so split off all copies of the pivot instead
        else if (s.isEmpty) {
            val (s, b) = arr partition (a == _)
            if (s.size > k) a
            else findKMedian(b, k - s.size)
        } else if (s.size < k) findKMedian(b, k - s.size)
        else findKMedian(s, k)
    }

    def findMedian(arr: Array[Double])(implicit choosePivot: Array[Double] => Double) =
        findKMedian(arr, (arr.size - 1) / 2)

Random Pivot (quadratic worst case, linear average), Immutable

This is the random pivot selection. Analysis of algorithms with random factors is trickier than normal, because it deals largely with probability and statistics.

def chooseRandomPivot(arr: Array[Double]): Double = arr(scala.util.Random.nextInt(arr.size)) 
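Because findKMedian returns the k-th smallest element (k is 0-based), the same routine also gives other order statistics. The sketch below repeats findKMedian and chooseRandomPivot so it compiles on its own; the wrapper object OrderStatistics and the method kthSmallest are illustrative names, not part of the answer. It passes the pivot function explicitly rather than declaring an implicit function value, since in Scala 2 an implicit value of type Array[Double] => Double in scope would also act as an implicit conversion.

```scala
import scala.annotation.tailrec
import scala.util.Random

object OrderStatistics {
  // findKMedian and chooseRandomPivot as defined above, repeated so this compiles alone
  @tailrec
  def findKMedian(arr: Array[Double], k: Int)(implicit choosePivot: Array[Double] => Double): Double = {
    val a = choosePivot(arr)
    val (s, b) = arr.partition(a > _)
    if (s.size == k) a
    else if (s.isEmpty) {
      val (eq, gt) = arr.partition(a == _)
      if (eq.size > k) a
      else findKMedian(gt, k - eq.size)
    } else if (s.size < k) findKMedian(b, k - s.size)
    else findKMedian(s, k)
  }

  def chooseRandomPivot(arr: Array[Double]): Double = arr(Random.nextInt(arr.size))

  // Hypothetical helper: k-th smallest with the pivot rule passed explicitly
  def kthSmallest(arr: Array[Double], k: Int): Double = {
    require(k >= 0 && k < arr.length, s"k must be in [0, ${arr.length})")
    findKMedian(arr, k)(chooseRandomPivot)
  }
}
```

For example, with xs = Array(7.0, 1.0, 5.0, 3.0, 9.0), kthSmallest(xs, 0), kthSmallest(xs, 2) and kthSmallest(xs, 4) return 1.0, 5.0 and 9.0 (the minimum, the median and the maximum).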

Median of Medians (linear), Immutable

The median of medians method guarantees linear time when used with the algorithm above. First, here is an algorithm to compute the median of up to 5 numbers, which is the basis of the median of medians algorithm. This one was provided by Rex Kerr in this answer; the overall algorithm depends heavily on its speed.

    // Median of 1 to 5 values (the lower median for 2 or 4 values).
    // Note: partially reorders its argument in place.
    def medianUpTo5(five: Array[Double]): Double = {
      // Swap a(i) and a(j) if they are out of order
      def order2(a: Array[Double], i: Int, j: Int) = {
        if (a(i) > a(j)) { val t = a(i); a(i) = a(j); a(j) = t }
      }

      // Given a(i) <= a(j) and a(k) <= a(l), return the second smallest of the four
      def pairs(a: Array[Double], i: Int, j: Int, k: Int, l: Int) = {
        if (a(i) < a(k)) { order2(a, j, k); a(j) }
        else { order2(a, i, l); a(i) }
      }

      if (five.length < 2) return five(0)
      order2(five, 0, 1)
      if (five.length < 4) return (
        if (five.length == 2 || five(2) < five(0)) five(0)
        else if (five(2) > five(1)) five(1)
        else five(2)
      )
      order2(five, 2, 3)
      if (five.length < 5) pairs(five, 0, 1, 2, 3)
      else if (five(0) < five(2)) { order2(five, 1, 4); pairs(five, 1, 4, 2, 3) }
      else { order2(five, 3, 4); pairs(five, 0, 1, 3, 4) }
    }
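medianUpTo5 returns the lower median for two or four values (the element at index (n - 1) / 2 after sorting, the same convention findMedian uses), and it reorders the array it is given. One way to gain confidence in a hand-written comparison sequence like this is an exhaustive check against sorting over all permutations of small inputs, including inputs with ties. The sketch below repeats medianUpTo5 so it compiles alone; MedianUpTo5Check and agreesWithSorting are illustrative names.

```scala
object MedianUpTo5Check {
  // medianUpTo5 as defined above, repeated so this compiles alone
  def medianUpTo5(five: Array[Double]): Double = {
    def order2(a: Array[Double], i: Int, j: Int): Unit =
      if (a(i) > a(j)) { val t = a(i); a(i) = a(j); a(j) = t }
    def pairs(a: Array[Double], i: Int, j: Int, k: Int, l: Int): Double =
      if (a(i) < a(k)) { order2(a, j, k); a(j) }
      else { order2(a, i, l); a(i) }

    if (five.length < 2) return five(0)
    order2(five, 0, 1)
    if (five.length < 4) return (
      if (five.length == 2 || five(2) < five(0)) five(0)
      else if (five(2) > five(1)) five(1)
      else five(2)
    )
    order2(five, 2, 3)
    if (five.length < 5) pairs(five, 0, 1, 2, 3)
    else if (five(0) < five(2)) { order2(five, 1, 4); pairs(five, 1, 4, 2, 3) }
    else { order2(five, 3, 4); pairs(five, 0, 1, 3, 4) }
  }

  // Compare against the sorted lower median on every permutation of `values`.
  // A fresh array is passed each time because medianUpTo5 reorders its argument.
  def agreesWithSorting(values: Seq[Double]): Boolean =
    values.permutations.forall { p =>
      val expected = p.sorted.apply((p.length - 1) / 2)
      medianUpTo5(p.toArray) == expected
    }
}
```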

And then, the median of medians algorithm itself. Basically, it guarantees that the chosen pivot will be greater than at least 30% of the list and smaller than at least another 30%, which is enough to guarantee the linearity of the previous algorithm. Look up the Wikipedia link provided in another answer for details.

    def medianOfMedians(arr: Array[Double]): Double = {
        val medians = arr.grouped(5).map(medianUpTo5).toArray
        if (medians.size <= 5) medianUpTo5(medians)
        else medianOfMedians(medians)
    }
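The 30% bound holds when the pivot is the exact median of the group medians. medianOfMedians above instead recurses on the medians themselves (a median of medians of medians, and so on), and the fraction of elements such a pivot is guaranteed to exceed shrinks with each level, so the 30% argument does not apply to it directly. A sketch closer to the textbook algorithm finds the exact median of the medians by calling the selection routine recursively with the same pivot rule. For brevity it sorts each group of at most five instead of calling medianUpTo5; ExactMedianOfMedians and medianOfMediansExact are illustrative names.

```scala
import scala.annotation.tailrec

object ExactMedianOfMedians {
  // Same selection logic as findKMedian above: the k-th smallest element (0-based)
  @tailrec
  def findKMedian(arr: Array[Double], k: Int)(implicit choosePivot: Array[Double] => Double): Double = {
    val a = choosePivot(arr)
    val (smaller, rest) = arr.partition(_ < a)
    if (smaller.length == k) a
    else if (smaller.isEmpty) {
      // The pivot is the minimum: drop all copies of it so the recursion makes progress
      val (equal, greater) = arr.partition(_ == a)
      if (equal.length > k) a else findKMedian(greater, k - equal.length)
    } else if (smaller.length < k) findKMedian(rest, k - smaller.length)
    else findKMedian(smaller, k)
  }

  // For brevity, each group of at most five is sorted instead of using medianUpTo5
  private def groupMedian(g: Array[Double]): Double = g.sorted.apply((g.length - 1) / 2)

  // Pivot rule: the exact (lower) median of the group medians, itself found by
  // recursive selection that uses this same pivot rule
  def medianOfMediansExact(arr: Array[Double]): Double = {
    val medians = arr.grouped(5).map(groupMedian).toArray
    if (medians.length <= 5) groupMedian(medians)
    else findKMedian(medians, (medians.length - 1) / 2)(medianOfMediansExact)
  }

  def findMedian(arr: Array[Double]): Double =
    findKMedian(arr, (arr.length - 1) / 2)(medianOfMediansExact)
}
```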

In-place Algorithm

So, here's an in-place version of the algorithm. I'm using a class that implements a partition in-place, with a backing array, so that the changes to the algorithms are minimal.

    case class ArrayView(arr: Array[Double], from: Int, until: Int) {
        def apply(n: Int) =
            if (from + n < until) arr(from + n)
            else throw new ArrayIndexOutOfBoundsException(n)

        // Reorders arr(from until until) so that the elements satisfying p come first,
        // and returns views of the two parts
        def partitionInPlace(p: Double => Boolean): (ArrayView, ArrayView) = {
          var upper = until - 1
          var lower = from
          // Loop until the two scans cross; "<=" also covers single-element views
          while (lower <= upper) {
            while (lower < until && p(arr(lower))) lower += 1
            while (upper >= from && !p(arr(upper))) upper -= 1
            if (lower < upper) { val tmp = arr(lower); arr(lower) = arr(upper); arr(upper) = tmp }
          }
          (copy(until = lower), copy(from = lower))
        }

        def size = until - from
        def isEmpty = size <= 0

        override def toString = arr.slice(from, until).mkString("ArrayView(", ", ", ")")
    }

    object ArrayView {
        def apply(arr: Array[Double]) = new ArrayView(arr, 0, arr.size)
    }

    @tailrec
    def findKMedianInPlace(arr: ArrayView, k: Int)(implicit choosePivot: ArrayView => Double): Double = {
        val a = choosePivot(arr)
        val (s, b) = arr partitionInPlace (a > _)
        if (s.size == k) a
        // The following test is used to avoid infinite repetition
        else if (s.isEmpty) {
            val (s, b) = arr partitionInPlace (a == _)
            if (s.size > k) a
            else findKMedianInPlace(b, k - s.size)
        } else if (s.size < k) findKMedianInPlace(b, k - s.size)
        else findKMedianInPlace(s, k)
    }

    def findMedianInPlace(arr: Array[Double])(implicit choosePivot: ArrayView => Double) =
        findKMedianInPlace(ArrayView(arr), (arr.size - 1) / 2)

Random Pivot, In-place

I'm only implementing the random pivot for the in-place algorithm, as the median of medians would require more support than the ArrayView class I defined currently provides.

def chooseRandomPivotInPlace(arr: ArrayView): Double = arr(scala.util.Random.nextInt(arr.size)) 
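For comparison, the same in-place idea can be written with explicit index bounds instead of an ArrayView. The sketch below (InPlaceSelect, select and median are illustrative names) uses a three-way partition, which separates the elements equal to the pivot in the same pass, so no extra equality pass is needed when the pivot happens to be the minimum. Like findMedianInPlace, it reorders the caller's array, so pass arr.clone() if the original order matters. It assumes a non-empty input with no NaN values.

```scala
import scala.annotation.tailrec
import scala.util.Random

object InPlaceSelect {
  private def swap(a: Array[Double], i: Int, j: Int): Unit = {
    val t = a(i); a(i) = a(j); a(j) = t
  }

  // k-th smallest (0-based, relative to lo) of a(lo until hi); reorders that range in place
  @tailrec
  def select(a: Array[Double], lo: Int, hi: Int, k: Int, rnd: Random): Double = {
    val pivot = a(lo + rnd.nextInt(hi - lo))
    // Three-way partition of a(lo until hi):
    //   a(lo until lt) < pivot, a(lt until gt) == pivot, a(gt until hi) > pivot
    var lt = lo
    var i = lo
    var gt = hi
    while (i < gt) {
      if (a(i) < pivot) { swap(a, lt, i); lt += 1; i += 1 }
      else if (a(i) > pivot) { gt -= 1; swap(a, i, gt) }
      else i += 1
    }
    if (k < lt - lo) select(a, lo, lt, k, rnd)          // target is below the pivot
    else if (k < gt - lo) pivot                         // target equals the pivot
    else select(a, gt, hi, k - (gt - lo), rnd)          // target is above the pivot
  }

  // Lower median (index (n - 1) / 2 after sorting); reorders a
  def median(a: Array[Double], rnd: Random = new Random()): Double = {
    require(a.nonEmpty, "median of an empty array")
    select(a, 0, a.length, (a.length - 1) / 2, rnd)
  }
}
```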

Histogram Algorithm (O(log(n)) memory), Immutable

So, about streams. It is impossible to use less than O(n) memory for a stream that can only be traversed once, unless you happen to know the stream length (in which case it ceases to be a stream in my book).
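When a single pass is all you get and O(n) memory is acceptable, a common technique (a substitute, not something from the answer) keeps the lower half of the values in a max-heap and the upper half in a min-heap. This stays within the O(n) bound above, since the two heaps together hold every value seen; each insertion costs O(log n), and the median is always at the top of the lower heap. The sketch uses scala.collection.mutable.PriorityQueue; the names RunningMedian and of are hypothetical. It returns the lower median, index (n - 1) / 2, the same convention as the other functions here.

```scala
import scala.collection.mutable

// Running median over a single pass. Memory is O(n): the heaps hold every value seen.
final class RunningMedian {
  // Lower half in a max-heap, upper half in a min-heap.
  // Invariant: every value in lower <= every value in upper, and
  // lower.size == upper.size or lower.size == upper.size + 1.
  private val lower = mutable.PriorityQueue.empty[Double]
  private val upper = mutable.PriorityQueue.empty[Double](Ordering[Double].reverse)

  def add(x: Double): Unit = {
    if (lower.isEmpty || x <= lower.head) lower.enqueue(x) else upper.enqueue(x)
    // Restore the size invariant by moving one value across
    if (lower.size > upper.size + 1) upper.enqueue(lower.dequeue())
    else if (upper.size > lower.size) lower.enqueue(upper.dequeue())
  }

  // Lower median of the values added so far
  def median: Double = {
    require(lower.nonEmpty, "no values added yet")
    lower.head
  }
}

object RunningMedian {
  def of(xs: Iterator[Double]): Double = {
    val rm = new RunningMedian
    xs.foreach(rm.add)
    rm.median
  }
}
```

For example, RunningMedian.of(Iterator(5.0, 1.0, 9.0)) returns 5.0.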

Using buckets is also a bit problematic, but if we can traverse it multiple times, then we can know its size, maximum and minimum, and work from there. For example:

    def findMedianHistogram(s: Traversable[Double]) = {
        def medianHistogram(s: Traversable[Double], discarded: Int, medianIndex: Int): Double = {
            // The buckets
            def numberOfBuckets = (math.log(s.size).toInt + 1) max 2
            val buckets = new Array[Int](numberOfBuckets)

            // The upper limit of each bucket
            val max = s.max
            val min = s.min
            val increment = (max - min) / numberOfBuckets
            val indices = (-numberOfBuckets + 1 to 0) map (max + increment * _)

            // Return the bucket a number is supposed to be in
            def bucketIndex(d: Double) = indices indexWhere (d <= _)

            // Compute how many in each bucket
            s foreach { d => buckets(bucketIndex(d)) += 1 }

            // Now make the buckets cumulative
            val partialTotals = buckets.scanLeft(discarded)(_ + _).drop(1)

            // The bucket that contains the target
            val medianBucket = partialTotals indexWhere (medianIndex < _)

            // Keep track of how many numbers are below the median bucket
            val newDiscarded = if (medianBucket == 0) discarded else partialTotals(medianBucket - 1)

            // Test whether a number is in the median bucket
            def insideMedianBucket(d: Double) = bucketIndex(d) == medianBucket

            // Get a view of the target bucket
            val view = s.view filter insideMedianBucket

            // If all numbers in the bucket are equal, return that
            if (view forall (_ == view.head)) view.head
            // Otherwise, recurse on that bucket
            else medianHistogram(view, newDiscarded, medianIndex)
        }

        medianHistogram(s, 0, (s.size - 1) / 2)
    }
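For the multiple-traversal case, a different approach (not the histogram method above) keeps only a constant number of values: the bounds of an open interval known to contain the median, the rank still being sought inside it, and a few counters. Each round makes one pass to pick a uniformly random candidate by reservoir sampling and a second pass to count the candidates below and equal to it; this is quickselect spread over passes, so the expected number of rounds is O(log n). MultiPassSelect and median are illustrative names; the sketch assumes a non-empty input with no NaN values.

```scala
import scala.util.Random

object MultiPassSelect {
  // Lower median of a non-empty, re-traversable collection, keeping O(1) extra values
  def median(xs: Iterable[Double], rnd: Random = new Random()): Double = {
    require(xs.nonEmpty, "median of an empty collection")
    var lo: Option[Double] = None // candidates are strictly greater than lo
    var hi: Option[Double] = None // candidates are strictly less than hi
    var k = (xs.size - 1) / 2     // rank of the target among the candidates
    def isCandidate(x: Double): Boolean = lo.forall(l => x > l) && hi.forall(h => x < h)

    var result: Option[Double] = None
    while (result.isEmpty) {
      // Pass 1: pick a uniformly random candidate (reservoir sampling of size one)
      var seen = 0
      var pivot = 0.0
      xs.foreach { x =>
        if (isCandidate(x)) {
          seen += 1
          if (rnd.nextInt(seen) == 0) pivot = x
        }
      }
      // Pass 2: count the candidates below and equal to the pivot
      var less = 0
      var equal = 0
      xs.foreach { x =>
        if (isCandidate(x)) {
          if (x < pivot) less += 1
          else if (x == pivot) equal += 1
        }
      }
      // Narrow the interval, as quickselect would recurse on one side
      if (k < less) hi = Some(pivot)
      else if (k < less + equal) result = Some(pivot)
      else { k -= less + equal; lo = Some(pivot) }
    }
    result.get
  }
}
```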

Test and Benchmark

To test the algorithms, I'm using ScalaCheck and comparing the output of each algorithm to the output of a trivial implementation based on sorting. That assumes the sorting version is correct, of course.

I'm benchmarking each of the above algorithms with all provided pivot selections, plus a fixed pivot selection (the middle of the array, rounded down). Each algorithm is run on three different input array sizes (100, 500 and 1000 elements), three times for each size.

Here's the testing code:

    import org.scalacheck.{Prop, Pretty, Test}
    import Prop._
    import Pretty._

    def test(algorithm: Array[Double] => Double,
             reference: Array[Double] => Double): String = {
        def prettyPrintArray(arr: Array[Double]) = arr.mkString("Array(", ", ", ")")
        val resultEqualsReference = forAll { (arr: Array[Double]) =>
            arr.nonEmpty ==> (algorithm(arr) == reference(arr)) :| prettyPrintArray(arr)
        }
        Test.check(Test.Params(), resultEqualsReference)(Pretty.Params(verbosity = 0))
    }

    import java.lang.System.currentTimeMillis

    def bench[A](n: Int)(body: => A): Long = {
      val start = currentTimeMillis()
      1 to n foreach { _ => body }
      currentTimeMillis() - start
    }

    import scala.util.Random.nextDouble

    def benchmark(algorithm: Array[Double] => Double,
                  arraySizes: List[Int]): List[Iterable[Long]] =
        for (size <- arraySizes)
        yield for (iteration <- 1 to 3)
            yield bench(50000)(algorithm(Array.fill(size)(nextDouble())))

    def testAndBenchmark: String = {
        val immutablePivotSelection: List[(String, Array[Double] => Double)] = List(
            "Random Pivot"      -> (chooseRandomPivot _),
            "Median of Medians" -> (medianOfMedians _),
            "Midpoint"          -> ((arr: Array[Double]) => arr((arr.size - 1) / 2))
        )
        val inPlacePivotSelection: List[(String, ArrayView => Double)] = List(
            "Random Pivot (in-place)" -> (chooseRandomPivotInPlace _),
            "Midpoint (in-place)"     -> ((arr: ArrayView) => arr((arr.size - 1) / 2))
        )
        val immutableAlgorithms = for ((name, pivotSelection) <- immutablePivotSelection)
            yield name -> (findMedian(_: Array[Double])(pivotSelection))
        val inPlaceAlgorithms = for ((name, pivotSelection) <- inPlacePivotSelection)
            yield name -> (findMedianInPlace(_: Array[Double])(pivotSelection))
        val histogramAlgorithm = "Histogram" -> ((arr: Array[Double]) => findMedianHistogram(arr))
        val sortingAlgorithm = "Sorting" -> ((arr: Array[Double]) => arr.sorted.apply((arr.size - 1) / 2))
        val algorithms = sortingAlgorithm :: histogramAlgorithm :: immutableAlgorithms ::: inPlaceAlgorithms

        val formattingString = "%%-%ds  %%s".format(algorithms.map(_._1.length).max)

        // Tests
        val testResults = for ((name, algorithm) <- algorithms)
            yield formattingString.format(name, test(algorithm, sortingAlgorithm._2))

        // Benchmarks
        val arraySizes = List(100, 500, 1000)
        def formatResults(results: List[Long]) = results.map("%8d".format(_)).mkString

        val benchmarkResults: List[String] = for {
            (name, algorithm) <- algorithms
            results <- benchmark(algorithm, arraySizes).transpose
        } yield formattingString.format(name, formatResults(results))

        val header = formattingString.format("Algorithm", formatResults(arraySizes.map(_.toLong)))

        ("Tests" :: "*****" :: testResults :::
            ("" :: "Benchmark" :: "*********" :: header :: benchmarkResults)).mkString("", "\n", "\n")
    }

Results

Tests:

    Tests
    *****
    Sorting                  OK, passed 100 tests.
    Histogram                OK, passed 100 tests.
    Random Pivot             OK, passed 100 tests.
    Median of Medians        OK, passed 100 tests.
    Midpoint                 OK, passed 100 tests.
    Random Pivot (in-place)  OK, passed 100 tests.
    Midpoint (in-place)      OK, passed 100 tests.

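Benchmarks (total time in milliseconds for 50,000 runs at each array size; each algorithm has one row per repetition):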

    Benchmark
    *********
    Algorithm                     100     500    1000
    Sorting                      1038    6230   14034
    Sorting                      1037    6223   13777
    Sorting                      1039    6220   13785
    Histogram                    2918   11065   21590
    Histogram                    2596   11046   21486
    Histogram                    2592   11044   21606
    Random Pivot                  904    4330    8622
    Random Pivot                  902    4323    8815
    Random Pivot                  896    4348    8767
    Median of Medians            3591   16857   33307
    Median of Medians            3530   16872   33321
    Median of Medians            3517   16793   33358
    Midpoint                     1003    4672    9236
    Midpoint                     1010    4755    9157
    Midpoint                     1017    4663    9166
    Random Pivot (in-place)       392    1746    3430
    Random Pivot (in-place)       386    1747    3424
    Random Pivot (in-place)       386    1751    3431
    Midpoint (in-place)           378    1735    3405
    Midpoint (in-place)           377    1740    3408
    Midpoint (in-place)           375    1736    3408

Analysis

All algorithms (except the sorting version) have results that are compatible with average linear time complexity.

The median of medians, which guarantees linear time complexity in the worst case, is much slower than the random pivot: roughly four times slower in these runs.

The fixed (midpoint) pivot selection is slightly slower than the random pivot, and it may perform much worse on non-random inputs.

The in-place versions are roughly 2.3 to 2.7 times as fast as the corresponding immutable versions, and further tests (not shown) seem to indicate this advantage grows with the size of the array.

I was very surprised by the histogram algorithm. It displayed linear average time complexity, and at the two larger array sizes it takes about a third less time than the median of medians. However, the input is random. The worst case is quadratic; I saw some examples of it while I was debugging the code.

answered Oct 16 '22 12:10 (16 revs)