I'm trying to understand how to use scalaz State to perform a complicated stateful computation. Here is the problem:
Given a List[Int] of potential divisors and a List[Int] of numbers, find a List[(Int, Int)] of matching pairs (divisor, number), where the divisor divides the number and each divisor is allowed to match at most one number (once a number is matched, it is no longer available to later divisors).
The function has this signature:
def findMatches(divs: List[Int], nums: List[Int]): List[(Int, Int)]
Given the following input:
findMatches( List(2, 3, 4), List(1, 6, 7, 8, 9) )
We can get at most 3 matches. If we stipulate that the matches must be made in the order in which they occur when traversing the lists left to right, then the matches must be:
List( (2, 6) , (3, 9) , (4, 8) )
So the following two tests need to pass:
assert(findMatches(List(2, 3, 4), List(1, 6, 7, 8, 9)) == List((2, 6), (3, 9), (4, 8)))
assert(findMatches(List(2, 3, 4), List(1, 6, 7, 8, 11)) == List((2, 6), (4, 8)))
Here's an imperative solution:
scala> def findMatches(divs: List[Int], nums: List[Int]): List[(Int, Int)] = {
| var matches = List.empty[(Int, Int)]
| var remaining = nums
| divs foreach { div =>
| remaining find (_ % div == 0) foreach { n =>
| remaining = remaining filterNot (_ == n)
| matches = matches ::: List(div -> n)
| }
| }
| matches
| }
findMatches: (divs: List[Int], nums: List[Int])List[(Int, Int)]
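For comparison, here is a sketch of the same algorithm with no vars, using only the Scala standard library (no scalaz): foldLeft carries the pair (matches so far, remaining numbers) from one divisor to the next. The object name FoldMatches is just for illustration.

```scala
object FoldMatches {
  // The "state" threaded through the fold: (matches found so far, numbers not yet used).
  def findMatches(divs: List[Int], nums: List[Int]): List[(Int, Int)] = {
    val (matches, _) =
      divs.foldLeft((List.empty[(Int, Int)], nums)) { case ((acc, remaining), div) =>
        remaining.find(_ % div == 0) match {
          // Record the pair and remove the matched number so no later divisor can use it.
          case Some(n) => (acc :+ (div -> n), remaining.filterNot(_ == n))
          // No match for this divisor: the state passes through unchanged.
          case None => (acc, remaining)
        }
      }
    matches
  }
}
```

This is exactly the bookkeeping that State and traverse are meant to hide: the fold's accumulator plays the role of the state.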
Notice that I have to update the state of remaining as well as accumulate matches. It sounds like a job for scalaz traverse!
My attempt so far, which does not compile:
scala> def findMatches(divs: List[Int], nums: List[Int]): List[(Int, Int)] = {
| divs.traverse[({type l[a] = State[List[Int], a]})#l, Int]( div =>
| state { (rem: List[Int]) => rem.find(_ % div == 0).map(n => rem.filterNot(_ == n) -> List(div -> n)).getOrElse(rem -> List.empty[(Int, Int)]) }
| ) ~> nums
| }
<console>:15: error: type mismatch;
found : List[(Int, Int)]
required: Int
state { (rem: List[Int]) => rem.find(_ % div == 0).map(n => rem.filterNot(_ == n) -> List(div -> n)).getOrElse(rem -> List.empty[(Int, Int)]) }
^
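The error comes from the explicit type arguments: traverse[..., Int] promises that every step yields an Int, but the lambda yields a List[(Int, Int)]. Note also that ~> returns the final state (here, the leftover numbers) rather than the matches, while ! returns the collected per-step outputs. Without scalaz, each step is just a function S => (S, B), with the same B for every divisor. The sketch below is stdlib-only; StepSketch, step and threadAll are illustrative names, and the foldLeft stands in for what traverse does with State.

```scala
object StepSketch {
  // One state transition: given the remaining numbers, return the updated remaining
  // numbers together with this divisor's output (state first, as in the pairs above).
  def step(div: Int): List[Int] => (List[Int], Option[(Int, Int)]) = remaining =>
    remaining.find(_ % div == 0) match {
      case Some(n) => (remaining.filterNot(_ == n), Some(div -> n))
      case None    => (remaining, None)
    }

  // Threading the state by hand, which is what traverse with State automates:
  // each step receives the state produced by the previous one.
  def threadAll(divs: List[Int], nums: List[Int]): (List[Int], List[Option[(Int, Int)]]) =
    divs.foldLeft((nums, List.empty[Option[(Int, Int)]])) { case ((rem, outs), div) =>
      val (rem1, out) = step(div)(rem)
      (rem1, outs :+ out)
    }
}
```

With B = Option[(Int, Int)], the final state is the unused numbers and flattening the outputs gives the matches.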
Your code only needs to be slightly modified in order to use State and Traverse:
// using scalaz-seven
import scalaz._
import Scalaz._
def findMatches(divs: List[Int], nums: List[Int]) = {
  // the "state" we carry when traversing
  case class S(matches: List[(Int, Int)], remaining: List[Int])
  // initially there are no found pairs and a full list of nums
  val initialState = S(List[(Int, Int)](), nums)
  // a function to find a pair (div, num) given the current "state"
  // we return a state transition that modifies the state
  def find(div: Int) = modify((s: S) =>
    s.remaining.find(_ % div == 0).map { (n: Int) =>
      // record the pair and drop n so that no later divisor can reuse it
      S(s.matches :+ div -> n, s.remaining filterNot (_ == n))
    }.getOrElse(s))
  // the traversal, with no type annotation thanks to Scalaz7
  // Note that we use `exec` to get the final state
  // instead of `eval` that would just give us a List[Unit].
  divs.traverseS(find).exec(initialState).matches
}
findMatches(List(2, 3, 4), List(1, 6, 7, 8, 9))
// List((2,6), (3,9), (4,8))
You can also use runTraverseS to write the traversal a bit differently; it returns the pair (final state, results), so the matches are in the first component:
divs.runTraverseS(initialState)(find)._1.matches
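To see what traverseS, modify, exec and eval are doing, here is a stdlib-only sketch with a hypothetical MiniState type. It follows the conventions described above (the result pair puts the state first, exec keeps only the final state, eval only the outputs), but it is a simplified illustration, not scalaz's implementation; MiniState, traverseState and ModifyMatches are made-up names.

```scala
// Hypothetical minimal State, for illustration only (not scalaz).
final case class MiniState[S, A](run: S => (S, A)) {
  def map[B](f: A => B): MiniState[S, B] = MiniState { (s: S) =>
    val (s1, a) = run(s)
    (s1, f(a))
  }
  def flatMap[B](f: A => MiniState[S, B]): MiniState[S, B] = MiniState { (s: S) =>
    val (s1, a) = run(s)
    f(a).run(s1) // the next step starts from the state this step left behind
  }
  def exec(s: S): S = run(s)._1 // final state only
  def eval(s: S): A = run(s)._2 // collected output only
}

object MiniState {
  def pure[S, A](a: A): MiniState[S, A] = MiniState((s: S) => (s, a))
  // Replace the state with f(state); the output is just ().
  def modify[S](f: S => S): MiniState[S, Unit] = MiniState((s: S) => (f(s), ()))
  // Run f on each element left to right, feeding each step the previous step's state.
  def traverseState[S, A, B](as: List[A])(f: A => MiniState[S, B]): MiniState[S, List[B]] =
    as.foldRight(pure[S, List[B]](Nil)) { (a, rest) =>
      f(a).flatMap(b => rest.map(bs => b :: bs))
    }
}

object ModifyMatches {
  final case class S(matches: List[(Int, Int)], remaining: List[Int])

  // Each step only modifies the state; its output is Unit.
  def find(div: Int): MiniState[S, Unit] = MiniState.modify[S] { s =>
    s.remaining.find(_ % div == 0) match {
      case Some(n) => S(s.matches :+ (div -> n), s.remaining.filterNot(_ == n))
      case None    => s
    }
  }

  def findMatches(divs: List[Int], nums: List[Int]): List[(Int, Int)] =
    MiniState.traverseState(divs)(find).exec(S(Nil, nums)).matches
}
```

Calling eval on the same traversal would give List((), (), ()), one Unit per divisor, which is why the answer reads the matches out of the final state with exec.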
I have finally figured this out after much messing about:
scala> def findMatches(divs: List[Int], nums: List[Int]): List[(Int, Int)] = {
| (divs.traverse[({type l[a] = State[List[Int], a]})#l, Option[(Int, Int)]]( div =>
| state { (rem: List[Int]) =>
| rem.find(_ % div == 0).map(n => rem.filterNot(_ == n) -> Some(div -> n)).getOrElse(rem -> none[(Int, Int)])
| }
| ) ! nums).flatten
| }
findMatches: (divs: List[Int], nums: List[Int])List[(Int, Int)]
I think I'll be looking at Eric's answer for more insight into what is actually going on, though.
Exploring Eric's answer using scalaz6
scala> def findMatches2(divs: List[Int], nums: List[Int]): List[(Int, Int)] = {
| case class S(matches: List[(Int, Int)], remaining: List[Int])
| val initialState = S(nil[(Int, Int)], nums)
| def find(div: Int, s: S) = {
| val newState = s.remaining.find(_ % div == 0).map { (n: Int) =>
| S(s.matches :+ div -> n, s.remaining filterNot (_ == n))
| }.getOrElse(s)
| newState -> newState.matches
| }
| val findDivs = (div: Int) => state((s: S) => find(div, s))
| (divs.traverse[({type l[a]=State[S, a]})#l, List[(Int, Int)]](findDivs) ! initialState).join
| }
findMatches2: (divs: List[Int], nums: List[Int])List[(Int, Int)]
scala> findMatches2(List(2, 3, 4), List(1, 6, 7, 8, 9))
res11: List[(Int, Int)] = List((2,6), (2,6), (3,9), (2,6), (3,9), (4,8))
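The join on the List[List[(Int, Int)]] at the end is causing grief: find returns newState.matches, the full list of matches accumulated so far, at every step, so joining the per-step results concatenates every prefix. Instead we can read the matches from the final state by replacing the last line with: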
(divs.traverse[({type l[a]=State[S, a]})#l, List[(Int, Int)]](findDivs) ~> initialState).matches
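Written out with plain lists, the per-step outputs for the example input are growing prefixes of the final answer. The values below are transcribed by hand from that run (JoinDuplicates is an illustrative name), and flattening them reproduces the duplicated result shown as res11.

```scala
object JoinDuplicates {
  // Output of each step when every step emits the accumulated matches so far.
  val perStepOutputs: List[List[(Int, Int)]] = List(
    List((2, 6)),
    List((2, 6), (3, 9)),
    List((2, 6), (3, 9), (4, 8))
  )
  // Concatenating the prefixes repeats earlier pairs (this is what join did).
  val joined: List[(Int, Int)] = perStepOutputs.flatten
  // The last prefix, like the matches in the final state, has each pair once.
  val lastOnly: List[(Int, Int)] = perStepOutputs.last
}
```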
In fact, you can do away with the extra output of a state computation altogether and simplify even further:
scala> def findMatches2(divs: List[Int], nums: List[Int]): List[(Int, Int)] = {
| case class S(matches: List[(Int, Int)], remaining: List[Int])
| def find(div: Int, s: S) =
| s.remaining.find(_ % div == 0).map( n => S(s.matches :+ div -> n, s.remaining filterNot (_ == n)) ).getOrElse(s) -> ()
| (divs.traverse[({type l[a]=State[S, a]})#l, Unit](div => state((s: S) => find(div, s))) ~> S(nil[(Int, Int)], nums)).matches
| }
findMatches2: (divs: List[Int], nums: List[Int])List[(Int, Int)]
modify, described above by Apocalisp, is also available in scalaz6 and removes the need to explicitly supply the (S, ()) pair (although you do still need Unit in the type lambda):
scala> def findMatches2(divs: List[Int], nums: List[Int]): List[(Int, Int)] = {
| case class S(matches: List[(Int, Int)], remaining: List[Int])
| def find(div: Int) = modify( (s: S) =>
| s.remaining.find(_ % div == 0).map( n => S(s.matches :+ div -> n, s.remaining filterNot (_ == n)) ).getOrElse(s))
| (divs.traverse[({type l[a]=State[S, a]})#l, Unit](div => state(s => find(div)(s))) ~> S(nil, nums)).matches
| }
findMatches2: (divs: List[Int], nums: List[Int])List[(Int, Int)]
scala> findMatches2(List(2, 3, 4), List(1, 6, 7, 8, 9))
res0: List[(Int, Int)] = List((2,6), (3,9), (4,8))