 

Cross product of an arbitrary number of lists in Scala

Tags:

scala

I have a list of lists in Scala as follows.

val inputList:List[List[Int]] = List(List(1, 2), List(3, 4, 5), List(1, 9))

I want the cross (Cartesian) product of all the sub-lists as a list of lists, i.e. every combination that takes one element from each sub-list:

val desiredOutput: List[List[Int]] = List( 
        List(1, 3, 1), List(1, 3, 9),
        List(1, 4, 1), List(1, 4, 9),
        List(1, 5, 1), List(1, 5, 9),
        List(2, 3, 1), List(2, 3, 9),
        List(2, 4, 1), List(2, 4, 9),
        List(2, 5, 1), List(2, 5, 9))

Neither the number of sub-lists in inputList nor the length of each sub-list is fixed. What is the idiomatic Scala way to do this?

asked Nov 26 '12 15:11 by dips



4 Answers

Here is a method that uses recursion. It is not tail recursive, so beware of stack overflows when the input contains a very large number of sub-lists (the recursion depth equals the number of sub-lists). It can, however, be turned into a tail-recursive function by using an auxiliary function.

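def getProduct(input: List[List[Int]]): List[List[Int]] = input match {
  case Nil => Nil // just in case you input an empty list
  case head :: Nil => head.map(_ :: Nil) // a single list: wrap each element
  case head :: tail =>
    val tailProduct = getProduct(tail) // compute once, not once per element of head
    for (elem <- head; sub <- tailProduct) yield elem :: sub
}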

Test:

scala> getProduct(inputList)
res32: List[List[Int]] = List(List(1, 3, 1), List(1, 3, 9), List(1, 4, 1), List(1, 4, 9), List(1, 5, 1), List(1, 5, 9), List(2, 3, 1), List(2, 3, 9), List(2, 4, 1), List(2, 4, 9), List(2, 5, 1), List(2, 5, 9))
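The answer mentions that getProduct can be made tail recursive with an auxiliary function. Here is a minimal sketch of one way to do that, assuming the same List[List[Int]] input; the names getProductTailRec and loop are hypothetical. It walks the sub-lists from last to first and prepends each element of the current sub-list to every partial combination built so far, which keeps the same output order as getProduct.

```scala
import scala.annotation.tailrec

object CrossProductTailRec {
  // Hypothetical tail-recursive variant of getProduct.
  def getProductTailRec(input: List[List[Int]]): List[List[Int]] = {
    // Auxiliary function: `acc` holds the combinations of the sub-lists
    // already processed (those to the right of `remaining`).
    @tailrec
    def loop(remaining: List[List[Int]], acc: List[List[Int]]): List[List[Int]] =
      remaining match {
        case Nil => acc
        case head :: rest =>
          // Prepend every element of this sub-list to every partial combination.
          loop(rest, for (elem <- head; sub <- acc) yield elem :: sub)
      }

    if (input.isEmpty) Nil // mirror getProduct's result for an empty input
    else loop(input.reverse, List(Nil))
  }

  def main(args: Array[String]): Unit = {
    val inputList = List(List(1, 2), List(3, 4, 5), List(1, 9))
    println(getProductTailRec(inputList))
  }
}
```

The @tailrec annotation makes the compiler reject the code if loop is not actually tail recursive, so the stack depth no longer grows with the number of sub-lists.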
answered Nov 01 '22 11:11 by Christopher Chiche



If you use Scalaz, this may be a suitable case for the applicative builder (|@|):

import scalaz._
import Scalaz._

def desiredOutput(input: List[List[Int]]) = 
  input.foldLeft(List(List.empty[Int]))((l, r) => (l |@| r)(_ :+ _))

desiredOutput(List(List(1, 2), List(3, 4, 5), List(1, 9)))

I am not very familiar with Scalaz myself, and I expect it offers an even more direct way to do this.
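For readers without Scalaz, here is a minimal sketch of the same left fold in plain Scala. It assumes that, for lists, the applicative step combines every partial combination in the accumulator with every element of the next list and appends that element (the _ :+ _ above); crossFoldLeft is a hypothetical name.

```scala
object CrossProductFoldLeft {
  // Hypothetical Scalaz-free version of the foldLeft above.
  def crossFoldLeft[A](input: List[List[A]]): List[List[A]] =
    input.foldLeft(List(List.empty[A])) { (acc, next) =>
      // Pair every partial combination with every element of `next`,
      // appending the element to the end of the partial combination.
      for {
        partial <- acc
        elem <- next
      } yield partial :+ elem
    }

  def main(args: Array[String]): Unit =
    println(crossFoldLeft(List(List(1, 2), List(3, 4, 5), List(1, 9))))
}
```

Unlike getProduct, this returns List(List()) (the single empty combination) for an empty input. Appending with :+ costs time proportional to the list length, which is usually negligible for short combinations.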

Edit

As Travis Brown suggests, we can simply write:

def desiredOutput(input: List[List[Int]]) = input.sequence

I also found the answers to this question very helpful for understanding what sequence does.

answered Nov 01 '22 12:11 by xiefei



If you don't mind a bit of functional programming:

def cross[T](inputs: List[List[T]]) : List[List[T]] = 
    inputs.foldRight(List[List[T]](Nil))((el, rest) => el.flatMap(p => rest.map(p :: _)))

Figuring out how that works is half the fun. :-)
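One way to see how cross works is to keep every intermediate accumulator of the right fold. The sketch below reuses the same step function, pulled out under the hypothetical name step, and swaps foldRight for scanRight, which returns all intermediate results (the final one first, the initial List(Nil) last).

```scala
object CrossFoldRightTrace {
  // Same step as in `cross`: prefix each element of `el` onto every
  // combination already built from the lists to its right.
  def step[T](el: List[T], rest: List[List[T]]): List[List[T]] =
    el.flatMap(p => rest.map(p :: _))

  def cross[T](inputs: List[List[T]]): List[List[T]] =
    inputs.foldRight(List[List[T]](Nil))((el, rest) => step(el, rest))

  // scanRight keeps every accumulator the right fold passes through.
  def trace[T](inputs: List[List[T]]): List[List[List[T]]] =
    inputs.scanRight(List[List[T]](Nil))((el, rest) => step(el, rest))

  def main(args: Array[String]): Unit =
    // Reverse so the output reads from the seed to the final result.
    trace(List(List(1, 2), List(3, 4, 5), List(1, 9))).reverse.foreach(println)
}
```

Printed in that order, the accumulators go from List(List()) to the combinations of the last sub-list alone (2 of them), then of the last two sub-lists (6), then of all three (12).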

answered Nov 01 '22 11:11 by Hans-Peter Störr



After several attempts, I arrived at this solution.

val inputList: List[List[Int]] = List(List(1, 2), List(3, 4, 5), List(1, 9))
val zss: List[List[Int]] = List(List())
def fun(xs: List[Int], zss: List[List[Int]]): List[List[Int]] = {
    for {
        x <- xs
        zs <- zss
    } yield {
        x :: zs
    }
}
val crossProd: List[List[Int]] = inputList.foldRight(zss)(fun _)
println(crossProd)
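All of these solutions build the complete result in memory, and its size is the product of the sub-list lengths (2 × 3 × 2 = 12 here). If that product gets large, a lazy variant may help. This is a minimal sketch assuming you can consume the combinations one at a time; crossIterator is a hypothetical name, and it uses Iterator instead of List so that nothing is computed until it is requested.

```scala
object LazyCrossProduct {
  // Hypothetical lazy cross product: combinations are produced on demand,
  // in the same order as the list-based answers.
  def crossIterator[A](input: List[List[A]]): Iterator[List[A]] =
    input match {
      // Empty input yields one empty combination (like the foldRight answers).
      case Nil => Iterator.single(Nil)
      case head :: tail =>
        // For each element of the first list, lazily extend every
        // combination of the remaining lists.
        head.iterator.flatMap(elem => crossIterator(tail).map(elem :: _))
    }

  def main(args: Array[String]): Unit = {
    val inputList = List(List(1, 2), List(3, 4, 5), List(1, 9))
    // Only the first three combinations are ever computed here.
    crossIterator(inputList).take(3).foreach(println)
  }
}
```

An Iterator can only be traversed once; call crossIterator again if you need a second pass.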
answered Nov 01 '22 10:11 by dips
