I have a list of lists in Scala as follows.
val inputList:List[List[Int]] = List(List(1, 2), List(3, 4, 5), List(1, 9))
I want a list of cross products of all the sub-lists.
val desiredOutput: List[List[Int]] = List(
  List(1, 3, 1), List(1, 3, 9),
  List(1, 4, 1), List(1, 4, 9),
  List(1, 5, 1), List(1, 5, 9),
  List(2, 3, 1), List(2, 3, 9),
  List(2, 4, 1), List(2, 4, 9),
  List(2, 5, 1), List(2, 5, 9))
Neither the number of sub-lists in inputList nor the length of each sub-list is fixed. What is the Scala way of doing this?
Here is a method that works using recursion. It is not tail recursive, and the recursion depth equals the number of sub-lists, so beware of stack overflows when the input has many sub-lists. It can be transformed into a tail-recursive function by using an auxiliary function.
def getProduct(input: List[List[Int]]): List[List[Int]] = input match {
  case Nil => Nil // just in case you input an empty list
  case head :: Nil => head.map(_ :: Nil)
  case head :: tail => for (elem <- head; sub <- getProduct(tail)) yield elem :: sub
}
Test:
scala> getProduct(inputList)
res32: List[List[Int]] = List(List(1, 3, 1), List(1, 3, 9), List(1, 4, 1), List(1, 4, 9), List(1, 5, 1), List(1, 5, 9), List(2, 3, 1), List(2, 3, 9), List(2, 4, 1), List(2, 4, 9), List(2, 5, 1), List(2, 5, 9))
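The tail-recursive transformation mentioned above could look roughly like the following. This is only a sketch: the names `getProductTailRec` and `loop` are made up for illustration, and the empty-input case is handled explicitly so that it returns `Nil`, as `getProduct` does. The auxiliary function walks the sub-lists from last to first and carries the partial products in an accumulator, so prepending each element keeps the same output order.

```scala
import scala.annotation.tailrec

// Hypothetical name; a sketch of the auxiliary-function approach.
def getProductTailRec(input: List[List[Int]]): List[List[Int]] = {
  // `acc` holds the cross product of the sub-lists processed so far.
  @tailrec
  def loop(remaining: List[List[Int]], acc: List[List[Int]]): List[List[Int]] =
    remaining match {
      case Nil => acc
      case head :: tail =>
        // Prepend every element of `head` to every partial product.
        loop(tail, for (elem <- head; sub <- acc) yield elem :: sub)
    }

  // Mirror getProduct: an empty input yields Nil rather than List(Nil).
  if (input.isEmpty) Nil
  else loop(input.reverse, List(Nil)) // start from the last sub-list
}
```

Because the recursive call is the last thing `loop` does, the compiler accepts the `@tailrec` annotation and the stack depth no longer grows with the number of sub-lists.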
If you use scalaz, this may be a suitable case for Applicative Builder:
import scalaz._
import Scalaz._
def desiredOutput(input: List[List[Int]]) =
  input.foldLeft(List(List.empty[Int]))((l, r) => (l |@| r)(_ :+ _))
desiredOutput(List(List(1, 2), List(3, 4, 5), List(1, 9)))
I am not very familiar with scalaz myself, and I expect it has some more powerful magic to do this.
Edit
As Travis Brown suggests, we can just write
def desiredOutput(input: List[List[Int]]) = input.sequence
I also found the answers to this question very helpful for understanding what sequence does.
If you don't mind a bit of functional programming:
def cross[T](inputs: List[List[T]]): List[List[T]] =
  inputs.foldRight(List[List[T]](Nil))((el, rest) => el.flatMap(p => rest.map(p :: _)))
Much fun finding out how that works. :-)
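One way to see how the fold works is to keep every intermediate accumulator instead of only the final one. The sketch below does that by swapping `foldRight` for the standard `scanRight` with the same seed and step function; `crossSteps` is a name made up for this illustration, and the snippet is written in REPL/script style.

```scala
def cross[T](inputs: List[List[T]]): List[List[T]] =
  inputs.foldRight(List[List[T]](Nil))((el, rest) => el.flatMap(p => rest.map(p :: _)))

// Hypothetical helper: same seed and step as `cross`, but scanRight
// returns all intermediate accumulators, final result first.
def crossSteps[T](inputs: List[List[T]]): List[List[List[T]]] =
  inputs.scanRight(List[List[T]](Nil))((el, rest) => el.flatMap(p => rest.map(p :: _)))

val steps = crossSteps(List(List(1, 2), List(3, 4, 5), List(1, 9)))

// Print in the order the fold builds them: seed first, full product last.
steps.reverse.foreach(println)
```

The seed `List(Nil)` is the product of zero lists: a single empty combination. Folding in the last sub-list `List(1, 9)` turns it into `List(List(1), List(9))`, and each earlier sub-list then prepends each of its elements to every combination built so far.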
After several attempts, I arrived at this solution.
val inputList: List[List[Int]] = List(List(1, 2), List(3, 4, 5), List(1, 9))
val zss: List[List[Int]] = List(List())
def fun(xs: List[Int], zss: List[List[Int]]): List[List[Int]] = {
  for {
    x <- xs
    zs <- zss
  } yield {
    x :: zs
  }
}
val crossProd: List[List[Int]] = inputList.foldRight(zss)(fun _)
println(crossProd)