Partition a collection into "k" close-to-equal pieces (Scala, but language agnostic)

Defined before this block of code:

  • dataset can be a Vector or List
  • numberOfSlices is an Int denoting how many slices to split dataset into

I want to split the dataset into numberOfSlices slices, distributed as evenly as possible. By "split" I really mean "partition" in the set-theory sense (the slices should be pairwise disjoint and their union should be the original), though this is not necessarily a set, just an arbitrary collection.

e.g.

dataset = List(1, 2, 3, 4, 5, 6, 7)
numberOfSlices = 3
slices == ListBuffer(Vector(1, 2), Vector(3, 4), Vector(5, 6, 7))

Is there a better way to do it than what I have below? (which I'm not even sure is optimal...) Or perhaps this is not an algorithmically feasible endeavor, in which case any known good heuristics?

import scala.collection.mutable.ListBuffer

val slices = new ListBuffer[Vector[Int]]
// every slice except the last gets exactly stepSize elements
val stepSize = dataset.length / numberOfSlices
var currentStep = 0
var looper = 0
while (looper != numberOfSlices) {
  if (looper != numberOfSlices - 1) {
    slices += dataset.slice(currentStep, currentStep + stepSize)
    currentStep += stepSize
  } else {
    // the last slice absorbs whatever remains
    slices += dataset.slice(currentStep, dataset.length)
  }
  looper += 1
}
asked Jul 12 '12 by adelbertc

3 Answers

If the behavior of xs.grouped(xs.size / n) doesn't work for you, it's pretty easy to define exactly what you want. The quotient is the size of the smaller pieces, and the remainder is the number of the bigger pieces:

def cut[A](xs: Seq[A], n: Int) = {
  // quot is the size of the smaller pieces; rem is how many pieces get one extra element
  val (quot, rem) = (xs.size / n, xs.size % n)
  // the first (n - rem) * quot elements form the smaller pieces, the rest the bigger ones
  val (smaller, bigger) = xs.splitAt(xs.size - rem * (quot + 1))
  // assumes n <= xs.size, so quot >= 1 (grouped needs a positive group size)
  smaller.grouped(quot) ++ bigger.grouped(quot + 1)
}
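
For the example from the question this produces the requested grouping. The REPL output below is a sketch worked out by hand rather than copied from a session:

scala> cut(List(1, 2, 3, 4, 5, 6, 7), 3).toList
res0: List[Seq[Int]] = List(List(1, 2), List(3, 4), List(5, 6, 7))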
answered by Travis Brown

The typical "optimal" partition computes the exact fractional boundary positions and then rounds them to determine how many elements each piece actually takes:

def cut[A](xs: Seq[A], n: Int): Vector[Seq[A]] = {
  val m = xs.length
  // ideal (rounded) boundary positions: 0, m/n, 2m/n, ..., m
  val targets = (0 to n).map{ x => math.round((x.toDouble * m) / n).toInt }
  // peel off consecutive boundary pairs, taking (j - i) elements each time
  def snip(xs: Seq[A], ns: Seq[Int], got: Vector[Seq[A]]): Vector[Seq[A]] = {
    if (ns.length < 2) got
    else {
      val (i, j) = (ns.head, ns.tail.head)
      snip(xs.drop(j - i), ns.tail, got :+ xs.take(j - i))
    }
  }
  snip(xs, targets, Vector.empty)
}

This way your longer and shorter blocks will be interspersed, which is often more desirable for evenness:

scala> cut(List(1,2,3,4,5,6,7,8,9,10),4)
res5: Vector[Seq[Int]] = 
  Vector(List(1, 2, 3), List(4, 5), List(6, 7, 8), List(9, 10))

You can even cut more times than you have elements:

scala> cut(List(1,2,3),5)
res6: Vector[Seq[Int]] = 
  Vector(List(1), List(), List(2), List(), List(3))
answered by Rex Kerr


Here's a one-liner that does the job for me, using the familiar Scala trick of a recursive function that returns a Stream. Notice the use of (x+k/2)/k to round the chunk sizes, which intersperses the smaller and larger chunks in the final list, with all sizes differing by at most one element. If you round up instead, with (x+k-1)/k, the smaller blocks move to the end, and x/k moves them to the beginning.

def k_folds(k: Int, vv: Seq[Int]): Stream[Seq[Int]] =
    if (k > 1)
        // take a rounded 1/k share of what is left, then recurse with k-1 folds
        vv.take((vv.size+k/2)/k) +: k_folds(k-1, vv.drop((vv.size+k/2)/k))
    else
        Stream(vv)

Demo:

scala> val indices = scala.util.Random.shuffle(1 to 39)

scala> for (ff <- k_folds(7, indices)) println(ff)
Vector(29, 8, 24, 14, 22, 2)
Vector(28, 36, 27, 7, 25, 4)
Vector(6, 26, 17, 13, 23)
Vector(3, 35, 34, 9, 37, 32)
Vector(33, 20, 31, 11, 16)
Vector(19, 30, 21, 39, 5, 15)
Vector(1, 38, 18, 10, 12)

scala> for (ff <- k_folds(7, indices)) println(ff.size)
6
6
5
6
5
6
5

scala> for (ff <- indices.grouped((indices.size+7-1)/7)) println(ff)
Vector(29, 8, 24, 14, 22, 2)
Vector(28, 36, 27, 7, 25, 4)
Vector(6, 26, 17, 13, 23, 3)
Vector(35, 34, 9, 37, 32, 33)
Vector(20, 31, 11, 16, 19, 30)
Vector(21, 39, 5, 15, 1, 38)
Vector(18, 10, 12)

scala> for (ff <- indices.grouped((indices.size+7-1)/7)) println(ff.size)
6
6
6
6
6
6
3

Notice how grouped does not try to even out the size of all the sub-lists.
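
To see the effect of the rounding alternatives mentioned at the top of this answer, here is a small sketch. k_folds_ceil is a hypothetical name for a variant that only swaps the rounding formula for (x+k-1)/k; the sizes in the comments are worked out by hand for a 7-element input rather than copied from a REPL:

def k_folds_ceil(k: Int, vv: Seq[Int]): Stream[Seq[Int]] =
    if (k > 1)
        vv.take((vv.size+k-1)/k) +: k_folds_ceil(k-1, vv.drop((vv.size+k-1)/k))
    else
        Stream(vv)

// k_folds(3, 1 to 7)      should give sizes 2, 3, 2  (shorter chunk interspersed)
// k_folds_ceil(3, 1 to 7) should give sizes 3, 2, 2  (shorter chunks pushed to the end)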

answered by dividebyzero