On occasion I take some time to play with Scala, whose mix of features appeals to me despite an inability to use it in my own work (thus far). For kicks I decided to try the first few 99 Haskell Problems in the most generic way possible: operating on and returning any kind of applicable collection. The first few questions weren’t too difficult, but I find myself utterly stymied by flatten. I just can’t figure out how to type such a thing.
To be specific about my question: is it possible to write a type-safe function that flattens arbitrarily-nested SeqLikes? So that, say,
flatten(List(Array(List(1, 2, 3), List(4, 5, 6)), Array(List(7, 8, 9), List(10, 11, 12))))
would return
List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12): List[Int]
Note that this isn’t quite the same question as in the Haskell and Scala problem sets; I’m trying to write a function that flattens not heterogeneous lists but, rather, homogeneous-but-nested sequences.
Searching the web I found a translation into Scala of that question, but it operates on and returns a List[Any]. Am I correct that this would require some kind of type recursion? Or am I making this out to be harder than it is?
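For contrast, the untyped approach mentioned above looks roughly like the sketch below. This is not the linked translation itself; the name flattenAny is made up here for illustration. It shows why the result is unsatisfying: the element type is erased to Any on the way in and on the way out.

```scala
// Untyped flatten in the style of the 99-problems solutions.
// `flattenAny` is a hypothetical name, not taken from the linked solution.
def flattenAny(xs: List[Any]): List[Any] = xs.flatMap {
  case nested: List[_] => flattenAny(nested) // recurse into nested lists
  case elem            => List(elem)         // anything else is a leaf
}

// The static type of the result is List[Any], even for all-Int input:
val flat: List[Any] = flattenAny(List(List(1, 2), 3))
```

It also only recognises List at runtime, so a nested Array or Vector would be treated as a leaf rather than flattened.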
The following works in Scala 2.10.0-M7. You will need to add extra cases for Array support, and perhaps refine it to have more specific output collection types, but I guess it can all be done starting from here:
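// Low-priority instance: a Seq whose elements are not themselves flattenable.
sealed trait InnerMost {
  implicit def innerSeq[A]: CanFlatten[Seq[A]] { type Elem = A } =
    new CanFlatten[Seq[A]] {
      type Elem = A
      def flatten(seq: Seq[A]): Seq[A] = seq
    }
}

// Higher-priority instance: a Seq whose elements can be flattened further.
// It is preferred over innerSeq because it lives in the derived object.
object CanFlatten extends InnerMost {
  implicit def nestedSeq[A](implicit inner: CanFlatten[A])
    : CanFlatten[Seq[A]] { type Elem = inner.Elem } =
    new CanFlatten[Seq[A]] {
      type Elem = inner.Elem
      def flatten(seq: Seq[A]): Seq[inner.Elem] =
        seq.flatMap(a => inner.flatten(a))
    }
}

// Type class: A can be flattened to a Seq of its innermost element type.
// Contravariance lets an instance for Seq[A] also serve List[A], Vector[A], etc.
sealed trait CanFlatten[-A] {
  type Elem
  def flatten(seq: A): Seq[Elem]
}

// Adds flattenAll to any A with a CanFlatten instance.
// (An implicit class must be nested in an object or class when compiled outside the REPL.)
implicit final class FlattenOp[A](val seq: A)(implicit val can: CanFlatten[A]) {
  def flattenAll: Seq[can.Elem] = can.flatten(seq)
}

// test
assert(List(1, 2, 3).flattenAll == Seq(1, 2, 3))
assert(List(Seq(List(1, 2, 3), List(4, 5, 6)),
            Seq(List(7, 8, 9), List(10, 11, 12))).flattenAll == (1 to 12).toSeq)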
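As for the Array support mentioned above, one way it might be added is with a second pair of instances that mirror the Seq ones. The following is an untested sketch, not part of the code above: the names FlattenWithArrays, innerArray and nestedArray are invented for illustration, and the whole thing is wrapped in an object so that the implicit class also compiles outside the REPL.

```scala
object FlattenWithArrays {
  sealed trait CanFlatten[-A] {
    type Elem
    def flatten(seq: A): Seq[Elem]
  }

  // Low-priority leaf instances: the elements are not flattened further.
  sealed trait InnerMost {
    implicit def innerSeq[A]: CanFlatten[Seq[A]] { type Elem = A } =
      new CanFlatten[Seq[A]] {
        type Elem = A
        def flatten(seq: Seq[A]): Seq[A] = seq
      }

    // Hypothetical addition: an Array of leaves is simply viewed as a Seq.
    implicit def innerArray[A]: CanFlatten[Array[A]] { type Elem = A } =
      new CanFlatten[Array[A]] {
        type Elem = A
        def flatten(arr: Array[A]): Seq[A] = arr.toSeq
      }
  }

  // Higher-priority recursive instances.
  object CanFlatten extends InnerMost {
    implicit def nestedSeq[A](implicit inner: CanFlatten[A])
      : CanFlatten[Seq[A]] { type Elem = inner.Elem } =
      new CanFlatten[Seq[A]] {
        type Elem = inner.Elem
        def flatten(seq: Seq[A]): Seq[inner.Elem] =
          seq.flatMap(a => inner.flatten(a))
      }

    // Hypothetical addition: Array is not a Seq, so it needs its own instance.
    implicit def nestedArray[A](implicit inner: CanFlatten[A])
      : CanFlatten[Array[A]] { type Elem = inner.Elem } =
      new CanFlatten[Array[A]] {
        type Elem = inner.Elem
        def flatten(arr: Array[A]): Seq[inner.Elem] =
          arr.toSeq.flatMap(a => inner.flatten(a))
      }
  }

  implicit final class FlattenOp[A](val seq: A)(implicit val can: CanFlatten[A]) {
    def flattenAll: Seq[can.Elem] = can.flatten(seq)
  }
}
```

With `import FlattenWithArrays._` in scope, the expression from the question, `List(Array(List(1, 2, 3), List(4, 5, 6)), Array(List(7, 8, 9), List(10, 11, 12))).flattenAll`, should then resolve through nestedSeq, nestedArray and innerSeq in turn. The result is still typed as a plain Seq[Int]; returning the outermost collection type would need further work.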