What's a fast implementation of median in Scala?
This is what I found on Rosetta Code:
```scala
def median(s: Seq[Double]) = {
  val (lower, upper) = s.sortWith(_<_).splitAt(s.size / 2)
  if (s.size % 2 == 0) (lower.last + upper.head) / 2.0 else upper.head
}
```
I don't like it because it does a sort. I know there are ways to compute the median in linear time.
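For reference, the Rosetta Code version averages the two middle values when the input size is even, while the answer below (and the sorting reference used in its tests) takes the lower median at index (n - 1) / 2. Here is a small sketch contrasting the two conventions. `MedianConventions` and `lowerMedian` are hypothetical names introduced here.

```scala
object MedianConventions {
  // Rosetta Code version: averages the two middle values when the size is even.
  def median(s: Seq[Double]): Double = {
    val (lower, upper) = s.sortWith(_ < _).splitAt(s.size / 2)
    if (s.size % 2 == 0) (lower.last + upper.head) / 2.0 else upper.head
  }

  // Hypothetical helper: the "lower median", i.e. index (n - 1) / 2 of the sorted input.
  def lowerMedian(s: Seq[Double]): Double = s.sortWith(_ < _).apply((s.size - 1) / 2)
}
```

Under the question's assumption of an odd number of values, the two functions agree. For `Seq(4.0, 1.0, 3.0, 2.0)`, the first returns 2.5 and the second returns 2.0.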
EDIT:
I would like to have a set of median functions that I can use in various scenarios:
O(log n)
values in memory like this O(log n)
values in memory, and you can traverse the stream at most once (is this even possible?)
Please only post code that compiles and correctly computes the median. For simplicity, you may assume that all inputs contain an odd number of values.
The first algorithm indicated by Taylor Leese is quadratic in the worst case but linear on average. That, however, depends on the pivot selection. So here I'm providing a version with pluggable pivot selection, along with both the random pivot and the median of medians pivot (which guarantees linear time).
```scala
import scala.annotation.tailrec

// Returns the element that would be at index k if arr were sorted.
@tailrec
def findKMedian(arr: Array[Double], k: Int)(implicit choosePivot: Array[Double] => Double): Double = {
  val a = choosePivot(arr)
  val (s, b) = arr partition (_ < a) // s: elements smaller than the pivot
  if (s.size == k) a
  // The following test is used to avoid infinite repetition
  else if (s.isEmpty) {
    val (s, b) = arr partition (_ == a)
    if (s.size > k) a
    else findKMedian(b, k - s.size)
  }
  else if (s.size < k) findKMedian(b, k - s.size)
  else findKMedian(s, k)
}

def findMedian(arr: Array[Double])(implicit choosePivot: Array[Double] => Double) =
  findKMedian(arr, (arr.size - 1) / 2)
```
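The dependence on the pivot can be made concrete by counting work. The sketch below is a hypothetical, non-recursive restatement of the same partition-and-recurse idea, not the code above. It charges one unit per element per partitioning round and takes the pivot rule as a parameter. With a first-element pivot on already-sorted input, each round discards only the pivot, so the work grows quadratically.

```scala
import scala.util.Random

object PivotWork {
  // Counts elements scanned while selecting index k0, one unit per element per round.
  // choosePivot must return an element of its argument, otherwise the loop may not end.
  def scanned(input: Array[Double], k0: Int)(choosePivot: Array[Double] => Double): Long = {
    var arr = input
    var k = k0
    var count = 0L
    var done = false
    while (!done) {
      val pivot = choosePivot(arr)
      count += arr.length
      val smaller = arr.filter(_ < pivot)
      val equal = arr.count(_ == pivot)
      if (k < smaller.length) arr = smaller             // target lies below the pivot
      else if (k < smaller.length + equal) done = true  // the pivot is the answer
      else {                                            // target lies above the pivot
        k -= smaller.length + equal
        arr = arr.filter(_ > pivot)
      }
    }
    count
  }

  // Fixed rule: always the first element.
  val firstElement: Array[Double] => Double = arr => arr(0)

  // Random rule drawing from the supplied generator.
  def randomElement(rng: Random): Array[Double] => Double = arr => arr(rng.nextInt(arr.length))
}
```

On `Array.tabulate(1000)(_.toDouble)` with k = 499, the first-element rule scans 375,250 elements (500 rounds of 1000, 999, ..., 501 elements). A random pivot typically needs only a few thousand.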
This is the random pivot selection. With it, the expected running time is linear, but analyzing algorithms with random factors is trickier than usual, because it deals largely with probability and statistics.
```scala
def chooseRandomPivot(arr: Array[Double]): Double = arr(scala.util.Random.nextInt(arr.size))
```
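Because `chooseRandomPivot` draws from the global `scala.util.Random`, runs are not reproducible. A hypothetical variant that closes over a caller-supplied, seeded generator yields the same pivot sequence for a given seed, which can help when investigating a slow or failing case.

```scala
import scala.util.Random

object SeededPivot {
  // Hypothetical helper: returns a pivot chooser that draws from `rng`,
  // so the same seed gives the same sequence of pivots.
  def seededRandomPivot(rng: Random): Array[Double] => Double =
    arr => arr(rng.nextInt(arr.length))
}
```

Since `findMedian` takes its pivot chooser as an implicit function parameter, the chooser can also be passed explicitly, e.g. `findMedian(arr)(SeededPivot.seededRandomPivot(new Random(42)))`.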
Next, the median of medians method, which guarantees linear time when used with the algorithm above. First, an algorithm to compute the median of up to 5 numbers, which is the basis of the median of medians algorithm. This one was provided by Rex Kerr in this answer; the speed of the whole algorithm depends heavily on it.
```scala
// Note: reorders the contents of `five` while computing the median.
def medianUpTo5(five: Array[Double]): Double = {
  def order2(a: Array[Double], i: Int, j: Int) = {
    if (a(i)>a(j)) { val t = a(i); a(i) = a(j); a(j) = t }
  }

  def pairs(a: Array[Double], i: Int, j: Int, k: Int, l: Int) = {
    if (a(i)<a(k)) { order2(a,j,k); a(j) }
    else { order2(a,i,l); a(i) }
  }

  if (five.length < 2) return five(0)
  order2(five,0,1)
  if (five.length < 4) return (
    if (five.length==2 || five(2) < five(0)) five(0)
    else if (five(2) > five(1)) five(1)
    else five(2)
  )
  order2(five,2,3)
  if (five.length < 5) pairs(five,0,1,2,3)
  else if (five(0) < five(2)) { order2(five,1,4); pairs(five,1,4,2,3) }
  else { order2(five,3,4); pairs(five,0,1,3,4) }
}
```
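When testing `medianUpTo5`, a slower but obviously correct reference can help. The sketch below is a hypothetical helper, not Rex Kerr's code. It insertion-sorts a copy and returns the element at index (n - 1) / 2, which is the same lower-median convention `medianUpTo5` follows for two or four values. `medianUpTo5` itself reorders its argument. That is harmless in `medianOfMedians`, because `grouped` produces fresh arrays.

```scala
object SmallMedian {
  // Hypothetical reference: lower median of one to five values via insertion sort on a copy.
  def medianUpTo5Simple(five: Array[Double]): Double = {
    require(five.nonEmpty && five.length <= 5, "expects 1 to 5 values")
    val a = five.clone() // leave the caller's array untouched
    for (i <- 1 until a.length) {
      val x = a(i)
      var j = i - 1
      // Shift larger elements one slot to the right.
      while (j >= 0 && a(j) > x) {
        a(j + 1) = a(j)
        j -= 1
      }
      a(j + 1) = x
    }
    a((a.length - 1) / 2)
  }
}
```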
And then, the median of medians algorithm itself. Basically, it guarantees that the chosen pivot will be greater than at least 30% of the list and smaller than at least another 30%, which is enough to guarantee the linearity of the previous algorithm. See the Wikipedia link provided in another answer for details.
```scala
def medianOfMedians(arr: Array[Double]): Double = {
  val medians = arr.grouped(5).map(medianUpTo5).toArray
  if (medians.size <= 5) medianUpTo5(medians)
  else medianOfMedians(medians)
}
```
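One caveat on the guarantee. The textbook median of medians takes the exact median of the group medians, found by calling the selection routine itself recursively, and that step is what bounds the pivot's rank by a constant fraction. Recursing through `medianOfMedians` alone, as above, returns a median of medians of medians, whose rank guarantee weakens at each level. Below is a self-contained sketch of the textbook form for comparison. `StrictMedianOfMedians` and `select` are hypothetical names, and the groups are sorted directly instead of using `medianUpTo5` to keep the sketch short.

```scala
object StrictMedianOfMedians {
  // Lower median of a group of at most five values, by sorting it.
  private def smallMedian(g: Array[Double]): Double = {
    val sorted = g.sortWith(_ < _)
    sorted((sorted.length - 1) / 2)
  }

  // Returns the element that would be at index k (0-based) if arr were sorted.
  def select(arr: Array[Double], k: Int): Double = {
    if (arr.length <= 5) arr.sortWith(_ < _).apply(k)
    else {
      val medians = arr.grouped(5).map(smallMedian).toArray
      // Textbook step: the pivot is the exact median of the medians, via recursive selection.
      val pivot = select(medians, (medians.length - 1) / 2)
      val less = arr.filter(_ < pivot)
      val equal = arr.count(_ == pivot)
      if (k < less.length) select(less, k)
      else if (k < less.length + equal) pivot
      else select(arr.filter(_ > pivot), k - less.length - equal)
    }
  }

  def median(arr: Array[Double]): Double = select(arr, (arr.length - 1) / 2)
}
```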
So, here's an in-place version of the algorithm. I'm using a class that implements in-place partitioning over a backing array, so that the changes to the algorithms are minimal.
```scala
import scala.annotation.tailrec

case class ArrayView(arr: Array[Double], from: Int, until: Int) {
  def apply(n: Int) =
    if (from + n < until) arr(from + n)
    else throw new ArrayIndexOutOfBoundsException(n)

  // Reorders arr(from until until) in place so that elements satisfying p come first,
  // and returns views over the two parts.
  def partitionInPlace(p: Double => Boolean): (ArrayView, ArrayView) = {
    var upper = until - 1
    var lower = from
    while (lower <= upper) { // "<=" so that a one-element view is also classified
      while (lower < until && p(arr(lower))) lower += 1
      while (upper >= from && !p(arr(upper))) upper -= 1
      if (lower < upper) { val tmp = arr(lower); arr(lower) = arr(upper); arr(upper) = tmp }
    }
    (copy(until = lower), copy(from = lower))
  }

  def size = until - from
  def isEmpty = size <= 0
  override def toString = arr.slice(from, until).mkString("ArrayView(", ", ", ")")
}

object ArrayView {
  def apply(arr: Array[Double]) = new ArrayView(arr, 0, arr.size)
}

@tailrec
def findKMedianInPlace(arr: ArrayView, k: Int)(implicit choosePivot: ArrayView => Double): Double = {
  val a = choosePivot(arr)
  val (s, b) = arr partitionInPlace (_ < a)
  if (s.size == k) a
  // The following test is used to avoid infinite repetition
  else if (s.isEmpty) {
    val (s, b) = arr partitionInPlace (_ == a)
    if (s.size > k) a
    else findKMedianInPlace(b, k - s.size)
  }
  else if (s.size < k) findKMedianInPlace(b, k - s.size)
  else findKMedianInPlace(s, k)
}

def findMedianInPlace(arr: Array[Double])(implicit choosePivot: ArrayView => Double) =
  findKMedianInPlace(ArrayView(arr), (arr.size - 1) / 2)
```
I'm only implementing the random pivot for the in-place algorithms, as the median of medians would require more support than the ArrayView class I defined currently provides.
```scala
def chooseRandomPivotInPlace(arr: ArrayView): Double = arr(scala.util.Random.nextInt(arr.size))
```
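For comparison, the same in-place idea can be written without the `ArrayView` wrapper, as a conventional quickselect over index ranges using Lomuto partitioning. This is a swapped-in textbook technique with hypothetical names, not the answer's `partitionInPlace`. Like `findMedianInPlace`, it reorders the caller's array, so pass `arr.clone()` if the original order matters.

```scala
import scala.util.Random

object LomutoQuickselect {
  private def swap(a: Array[Double], i: Int, j: Int): Unit = {
    val t = a(i); a(i) = a(j); a(j) = t
  }

  // Lomuto partition of a(lo..hi) around the value at pivotIndex.
  // Returns the pivot's final index; everything to its left is smaller than the pivot.
  private def partition(a: Array[Double], lo: Int, hi: Int, pivotIndex: Int): Int = {
    val pivot = a(pivotIndex)
    swap(a, pivotIndex, hi) // park the pivot at the end
    var store = lo
    for (i <- lo until hi) {
      if (a(i) < pivot) {
        swap(a, i, store)
        store += 1
      }
    }
    swap(a, store, hi) // move the pivot into its final place
    store
  }

  // Reorders `a` in place and returns the value that belongs at index k in sorted order.
  def select(a: Array[Double], k: Int, rng: Random = new Random): Double = {
    require(k >= 0 && k < a.length, "k out of range")
    var lo = 0
    var hi = a.length - 1
    while (lo < hi) {
      val p = partition(a, lo, hi, lo + rng.nextInt(hi - lo + 1))
      if (k == p) return a(k)
      else if (k < p) hi = p - 1
      else lo = p + 1
    }
    a(k)
  }
}
```

Lomuto partitioning sends values equal to the pivot to one side, so inputs with many duplicates degrade toward quadratic time. The answer's version handles the case where nothing is smaller than the pivot with its separate equality partition.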
So, about streams. It is impossible to use less than O(n) memory for a stream that can only be traversed once, unless you happen to know what the stream length is (in which case it ceases to be a stream in my book).
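The standard exact single-pass approach is consistent with that bound. It keeps every element seen, split across two heaps (a max-heap for the lower half and a min-heap for the upper half), and it can report the median of the prefix read so far at any point. This is a swapped-in textbook technique with a hypothetical class name, `RunningMedian`, and its memory use is O(n).

```scala
import scala.collection.mutable

// Hypothetical helper: exact running median in one pass; stores every value (O(n) memory).
final class RunningMedian {
  // Max-heap holding the smaller half of the values seen so far.
  private val lower = mutable.PriorityQueue.empty[Double](Ordering.fromLessThan[Double](_ < _))
  // Min-heap holding the larger half (the reversed comparison puts the smallest value first).
  private val upper = mutable.PriorityQueue.empty[Double](Ordering.fromLessThan[Double](_ > _))

  def add(x: Double): Unit = {
    if (lower.isEmpty || x <= lower.head) lower.enqueue(x) else upper.enqueue(x)
    // Rebalance so that lower holds as many values as upper, or exactly one more.
    if (lower.size > upper.size + 1) upper.enqueue(lower.dequeue())
    else if (upper.size > lower.size) lower.enqueue(upper.dequeue())
  }

  // Lower median of the values added so far (index (n - 1) / 2 in sorted order).
  def median: Double = {
    require(lower.nonEmpty, "no values added yet")
    lower.head
  }
}
```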
Using buckets is also a bit problematic, but if we can traverse the stream multiple times, then we can learn its size, maximum and minimum, and work from there. For example:
```scala
def findMedianHistogram(s: Iterable[Double]) = {
  def medianHistogram(s: Iterable[Double], discarded: Int, medianIndex: Int): Double = {
    // The buckets
    val numberOfBuckets = (math.log(s.size).toInt + 1) max 2
    val buckets = new Array[Int](numberOfBuckets)

    // The upper limit of each bucket
    val max = s.max
    val min = s.min
    val increment = (max - min) / numberOfBuckets
    val indices = (-numberOfBuckets + 1 to 0) map (max + increment * _)

    // Return the bucket a number is supposed to be in
    def bucketIndex(d: Double) = indices indexWhere (d <= _)

    // Compute how many in each bucket
    s foreach { d => buckets(bucketIndex(d)) += 1 }

    // Now make the buckets cumulative
    val partialTotals = buckets.scanLeft(discarded)(_+_).drop(1)

    // The bucket where our target is at
    val medianBucket = partialTotals indexWhere (medianIndex < _)

    // Keep track of how many numbers there are that are less
    // than the median bucket
    val newDiscarded = if (medianBucket == 0) discarded else partialTotals(medianBucket - 1)

    // Test whether a number is in the median bucket
    def insideMedianBucket(d: Double) = bucketIndex(d) == medianBucket

    // Get a view of the target bucket
    val view = s.view filter insideMedianBucket

    // If all numbers in the bucket are equal, return that
    if (view forall (_ == view.head)) view.head
    // Otherwise, recurse on that bucket
    else medianHistogram(view, newDiscarded, medianIndex)
  }

  medianHistogram(s, 0, (s.size - 1) / 2)
}
```
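A different multi-pass approach, swapped in here for comparison rather than taken from the answer, avoids the quadratic worst case. Instead of adaptive buckets, it binary-searches over the doubles' ordered bit patterns. Each step is one counting pass over the data, so an exact lower median takes at most 64 steps (plus three setup passes) with O(1) extra memory. The sketch assumes the input contains no NaN values. `BitBisectionMedian` and its helpers are hypothetical names.

```scala
// Hypothetical sketch: exact lower median with O(1) extra memory by re-scanning the data.
object BitBisectionMedian {
  // Maps a non-NaN double to a Long whose signed order matches the double's numeric order
  // (with -0.0 placed just below 0.0). The mapping is its own inverse.
  private def key(d: Double): Long = {
    val bits = java.lang.Double.doubleToLongBits(d)
    bits ^ ((bits >> 63) & Long.MaxValue)
  }

  private def fromKey(k: Long): Double =
    java.lang.Double.longBitsToDouble(k ^ ((k >> 63) & Long.MaxValue))

  // Floor of (a + b) / 2 without overflowing Long.
  private def midpoint(a: Long, b: Long): Long = (a >> 1) + (b >> 1) + (a & b & 1L)

  def median(data: Iterable[Double]): Double = {
    val n = data.size // setup pass 1
    require(n > 0, "empty input")
    val target = (n - 1) / 2 // index of the lower median
    var lo = data.iterator.map(key).min // setup pass 2
    var hi = data.iterator.map(key).max // setup pass 3
    // Invariant: the answer's key lies in [lo, hi]; each step halves the range with one pass.
    while (lo < hi) {
      val mid = midpoint(lo, hi)
      val atOrBelow = data.count(d => key(d) <= mid)
      if (atOrBelow > target) hi = mid else lo = mid + 1
    }
    fromKey(lo) // the smallest key with more than `target` values at or below it
  }
}
```

It will usually make more passes than the histogram version does on random data, in exchange for a fixed bound on the number of traversals.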
To test the algorithms, I'm using ScalaCheck and comparing the output of each algorithm to the output of a trivial implementation based on sorting. That assumes the sorting version is correct, of course.
I'm benchmarking each of the above algorithms with all the provided pivot selections, plus a fixed pivot selection (the middle of the array, rounding down). Each algorithm is run on three different input array sizes, three times for each size, and each run times 50,000 calls.
Here's the testing code:
```scala
import org.scalacheck.{Prop, Pretty, Test}
import Prop._
import Pretty._

def test(algorithm: Array[Double] => Double, reference: Array[Double] => Double): String = {
  def prettyPrintArray(arr: Array[Double]) = arr mkString ("Array(", ", ", ")")
  val resultEqualsReference = forAll { (arr: Array[Double]) =>
    arr.nonEmpty ==> (algorithm(arr) == reference(arr)) :| prettyPrintArray(arr)
  }
  // Older ScalaCheck releases spelled the default parameters Test.Params()
  Test.check(Test.Parameters.default, resultEqualsReference)(Pretty.Params(verbosity = 0))
}

import java.lang.System.currentTimeMillis

def bench[A](n: Int)(body: => A): Long = {
  val start = currentTimeMillis()
  1 to n foreach { _ => body }
  currentTimeMillis() - start
}

import scala.util.Random.nextDouble

def benchmark(algorithm: Array[Double] => Double, arraySizes: List[Int]): List[Iterable[Long]] =
  for (size <- arraySizes) yield
    for (iteration <- 1 to 3) yield
      bench(50000)(algorithm(Array.fill(size)(nextDouble())))

def testAndBenchmark: String = {
  val immutablePivotSelection: List[(String, Array[Double] => Double)] = List(
    "Random Pivot"      -> (chooseRandomPivot _),
    "Median of Medians" -> (medianOfMedians _),
    "Midpoint"          -> ((arr: Array[Double]) => arr((arr.size - 1) / 2))
  )
  val inPlacePivotSelection: List[(String, ArrayView => Double)] = List(
    "Random Pivot (in-place)" -> (chooseRandomPivotInPlace _),
    "Midpoint (in-place)"     -> ((arr: ArrayView) => arr((arr.size - 1) / 2))
  )
  val immutableAlgorithms = for ((name, pivotSelection) <- immutablePivotSelection)
    yield name -> (findMedian(_: Array[Double])(pivotSelection))
  val inPlaceAlgorithms = for ((name, pivotSelection) <- inPlacePivotSelection)
    yield name -> (findMedianInPlace(_: Array[Double])(pivotSelection))
  val histogramAlgorithm = "Histogram" -> ((arr: Array[Double]) => findMedianHistogram(arr))
  val sortingAlgorithm = "Sorting" -> ((arr: Array[Double]) => arr.sorted.apply((arr.size - 1) / 2))
  val algorithms = sortingAlgorithm :: histogramAlgorithm :: immutableAlgorithms ::: inPlaceAlgorithms
  val formattingString = "%%-%ds %%s" format algorithms.map(_._1.length).max

  // Tests
  val testResults = for ((name, algorithm) <- algorithms)
    yield formattingString format (name, test(algorithm, sortingAlgorithm._2))

  // Benchmarks
  val arraySizes = List(100, 500, 1000)
  def formatResults(results: List[Long]) = results.map("%8d" format _).mkString
  val benchmarkResults: List[String] = for {
    (name, algorithm) <- algorithms
    results <- benchmark(algorithm, arraySizes).transpose
  } yield formattingString format (name, formatResults(results))
  val header = formattingString format ("Algorithm", formatResults(arraySizes.map(_.toLong)))

  val lines = "Tests" :: "*****" :: testResults :::
    ("" :: "Benchmark" :: "*********" :: header :: benchmarkResults)
  lines.mkString("", "\n", "\n")
}
```
Tests:
```
Tests
*****
Sorting                 OK, passed 100 tests.
Histogram               OK, passed 100 tests.
Random Pivot            OK, passed 100 tests.
Median of Medians       OK, passed 100 tests.
Midpoint                OK, passed 100 tests.
Random Pivot (in-place) OK, passed 100 tests.
Midpoint (in-place)     OK, passed 100 tests.
```
Benchmarks:
```
Benchmark
*********
Algorithm                    100     500    1000
Sorting                     1038    6230   14034
Sorting                     1037    6223   13777
Sorting                     1039    6220   13785
Histogram                   2918   11065   21590
Histogram                   2596   11046   21486
Histogram                   2592   11044   21606
Random Pivot                 904    4330    8622
Random Pivot                 902    4323    8815
Random Pivot                 896    4348    8767
Median of Medians           3591   16857   33307
Median of Medians           3530   16872   33321
Median of Medians           3517   16793   33358
Midpoint                    1003    4672    9236
Midpoint                    1010    4755    9157
Midpoint                    1017    4663    9166
Random Pivot (in-place)      392    1746    3430
Random Pivot (in-place)      386    1747    3424
Random Pivot (in-place)      386    1751    3431
Midpoint (in-place)          378    1735    3405
Midpoint (in-place)          377    1740    3408
Midpoint (in-place)          375    1736    3408
```
All algorithms (except the sorting version) have results that are compatible with average linear time complexity.
The median of medians, which guarantees linear time complexity in the worst case, is much slower than the random pivot: roughly four times slower at every size tested.
The fixed pivot selection is slightly slower than the random pivot (by about 10% here), but may perform much worse on non-random inputs.
The in-place version runs about 2.3 to 2.5 times as fast as the immutable one, and further tests (not shown) seem to indicate that this advantage grows with the size of the array.
I was very surprised by the histogram algorithm. It displayed average linear time complexity, and it takes about a third less time than the median of medians at the two larger sizes. However, the input here is random. The worst case is quadratic; I saw some examples of it while I was debugging the code.
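The relative figures quoted above can be checked against the first run of each algorithm in the benchmark table. Here is a small sketch of that arithmetic, with the timings copied from the table. `BenchmarkRatios` is a hypothetical name.

```scala
object BenchmarkRatios {
  // First-run timings (ms for 50,000 calls) copied from the table, for sizes 100, 500 and 1000.
  val randomPivot        = Vector(904L, 4330L, 8622L)
  val randomPivotInPlace = Vector(392L, 1746L, 3430L)
  val medianOfMedians    = Vector(3591L, 16857L, 33307L)
  val histogram          = Vector(2918L, 11065L, 21590L)

  // Element-wise ratio of two timing rows.
  def ratios(slow: Vector[Long], fast: Vector[Long]): Vector[Double] =
    slow.zip(fast).map { case (s, f) => s.toDouble / f }
}
```

For example, `ratios(randomPivot, randomPivotInPlace)` gives about 2.31, 2.48 and 2.51, and `ratios(histogram, medianOfMedians)` gives about 0.81, 0.66 and 0.65.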