I'm using the State monad from the Scala Cats library to compose imperative sequences of state transitions in a functional manner.
My actual use-case is quite complicated, so to simplify matters, consider the following minimal problem: there is a Counter state that keeps a count value that may be incremented or decremented; however, it is an error if the count becomes negative or overflows. In the event that an error is encountered, I need to preserve the state at the time of the error and effectively stop processing subsequent state transitions.
I'm using the return value of each state transition to report any errors, using the type Try[Unit]. An operation that completes successfully then returns the new state plus the value Success(()), while a failure returns the existing state plus an exception wrapped in Failure.
Note: Clearly, I could just throw an exception when I encounter an error. However, this would violate referential transparency and would also require that I do some extra work to store the counter state in the thrown exception. I also discounted using a Try[Counter] as the state type (instead of just Counter), since I cannot use this to track both the failure and the failed state. One option I haven't explored is using a (Counter, Try[Unit]) tuple as the state, because that just seems too cumbersome, but I'm open to suggestions.
import cats.data.State
import scala.util.{Failure, Success, Try}
// State being maintained: an immutable counter.
final case class Counter(count: Int)
// Type for state transition operations.
type Transition[M] = State[Counter, Try[M]]
// Operation to increment a counter.
val increment: Transition[Unit] = State { c =>
  // If the count is at its maximum, incrementing it must fail.
  if (c.count == Int.MaxValue) {
    (c, Failure(new ArithmeticException("Attempt to overflow counter failed")))
  }
  // Otherwise, increment the count and indicate success.
  else (c.copy(count = c.count + 1), Success(()))
}
// Operation to decrement a counter.
val decrement: Transition[Unit] = State { c =>
  // If the count is zero, decrementing it must fail.
  if (c.count == 0) {
    (c, Failure(new ArithmeticException("Attempt to make count negative failed")))
  }
  // Otherwise, decrement the count and indicate success.
  else (c.copy(count = c.count - 1), Success(()))
}
However, I'm struggling to determine the best approach to stringing transitions together, while dealing with any failures in the desired manner. (If you prefer, a more general statement of my problem is that I need to perform subsequent transitions conditionally upon a returned value of the previous transition.)
For example, the following set of transitions might fail at the first, third or fourth step (but let's assume it can fail at the second too), depending upon the counter's starting state, but it will still attempt to perform the next step unconditionally:
val counterManip: Transition[Unit] = for {
  _ <- decrement
  _ <- increment
  _ <- increment
  r <- increment
} yield r
If I run this code with an initial counter value of 0, clearly what I will get is a new counter value of 3 and a Success(()), since that is the result of the last step:
scala> counterManip.run(Counter(0)).value
res0: (Counter, scala.util.Try[Unit]) = (Counter(3),Success(()))
but what I want is to get the initial counter state (the state that failed the decrement operation) and an ArithmeticException wrapped in Failure, since the first step failed.
The only solution I've been able to come up with so far is hideously complicated, repetitive and error prone:
val counterManip: Transition[Unit] = State { s0 =>
  val r1 = decrement.run(s0).value
  if (r1._2.isFailure) r1
  else {
    val r2 = increment.run(r1._1).value
    if (r2._2.isFailure) r2
    else {
      val r3 = increment.run(r2._1).value
      if (r3._2.isFailure) r3
      else increment.run(r3._1).value
    }
  }
}
which gives the correct result:
scala> counterManip.run(Counter(0)).value
res1: (Counter, scala.util.Try[Unit]) = (Counter(0),Failure(java.lang.ArithmeticException: Attempt to make count negative failed))
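The nested run(...).value calls above can be folded into State's own flatMap: the function passed to flatMap receives the previous step's Try, so it can either run the next transition or re-emit the failure with State.pure, which leaves the state untouched. A minimal sketch, assuming cats-core on the classpath and repeating the question's definitions so the block stands alone; andThenIfSuccess is a hypothetical helper name, not a Cats API.

```scala
import cats.data.State
import scala.util.{Failure, Success, Try}

final case class Counter(count: Int)
type Transition[M] = State[Counter, Try[M]]

val increment: Transition[Unit] = State { c =>
  if (c.count == Int.MaxValue) (c, Failure(new ArithmeticException("Attempt to overflow counter failed")))
  else (c.copy(count = c.count + 1), Success(()))
}

val decrement: Transition[Unit] = State { c =>
  if (c.count == 0) (c, Failure(new ArithmeticException("Attempt to make count negative failed")))
  else (c.copy(count = c.count - 1), Success(()))
}

// Hypothetical helper: run `next` only if `first` returned Success.
// On Failure, State.pure re-emits the failure without changing the state,
// so the state at the point of failure is what comes out at the end.
def andThenIfSuccess[A, B](first: Transition[A], next: => Transition[B]): Transition[B] =
  first.flatMap { (result: Try[A]) =>
    result match {
      case Success(_) => next
      case Failure(e) => State.pure[Counter, Try[B]](Failure(e))
    }
  }

// Same sequence as the nested version: decrement, then three increments.
val counterManip: Transition[Unit] =
  andThenIfSuccess(decrement, andThenIfSuccess(increment, andThenIfSuccess(increment, increment)))
```

Starting from Counter(0) this stops at the failed decrement and returns Counter(0) with the Failure; starting from Counter(1) it ends at Counter(3) with Success(()), matching the REPL results above.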
Update
I came up with the untilFailure method, below, for running sequences of transitions until they're complete or until an error occurs (whichever comes first). I'm warming to it as it's simple and elegant to use.
However, I'm still curious whether there's an elegant way to chain the transitions together more directly. (For example, if the transitions were just regular functions that returned Try[T], and had no state, then we could chain calls together using flatMap, allowing construction of a for expression which would pass the result of successful transitions to the next transition.)
Can you suggest a better approach?
Doh! I don't know why this didn't occur to me sooner. Sometimes just explaining your problem in simpler terms forces you to look at it afresh, I guess...
One possibility is to handle sequences of transitions, so that the next task is only undertaken if the current task succeeds.
// Run a sequence of transitions, until one fails.
def untilFailure[M](ts: List[Transition[M]]): Transition[M] = State { s =>
  ts match {
    // If we have an empty list, that's an error. (Cannot report a success value.)
    case Nil => (s, Failure(new RuntimeException("Empty transition sequence")))
    // If there's only one transition left, perform it and return the result.
    case t :: Nil => t.run(s).value
    // Otherwise, we have more than one transition remaining.
    //
    // Run the next transition. If it fails, report the failure, otherwise repeat
    // for the tail.
    case t :: tt => {
      val r = t.run(s).value
      if (r._2.isFailure) r
      else untilFailure(tt).run(r._1).value
    }
  }
}
We can then implement counterManip as a sequence.
val counterManip: Transition[Unit] = untilFailure(List(decrement, increment, increment, increment))
which gives the correct results:
scala> counterManip.run(Counter(0)).value
res0: (Counter, scala.util.Try[Unit]) = (Counter(0),Failure(java.lang.ArithmeticException: Attempt to make count negative failed))
scala> counterManip.run(Counter(1)).value
res1: (Counter, scala.util.Try[Unit]) = (Counter(3),Success(()))
scala> counterManip.run(Counter(Int.MaxValue - 2)).value
res2: (Counter, scala.util.Try[Unit]) = (Counter(2147483647),Success(()))
scala> counterManip.run(Counter(Int.MaxValue - 1)).value
res3: (Counter, scala.util.Try[Unit]) = (Counter(2147483647),Failure(java.lang.ArithmeticException: Attempt to overflow counter failed))
scala> counterManip.run(Counter(Int.MaxValue)).value
res4: (Counter, scala.util.Try[Unit]) = (Counter(2147483647),Failure(java.lang.ArithmeticException: Attempt to overflow counter failed))
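The same behaviour can be sketched as a variant of untilFailure that needs neither the empty-list failure nor the manual run(...).value calls: taking a cats.data.NonEmptyList guarantees there is a first transition, and folding the rest in with flatMap stops at the first Failure while keeping that step's state. This is an illustrative alternative, not the code above; it assumes cats-core, as the question does, and repeats the question's definitions so it stands alone.

```scala
import cats.data.{NonEmptyList, State}
import scala.util.{Failure, Success, Try}

final case class Counter(count: Int)
type Transition[M] = State[Counter, Try[M]]

val increment: Transition[Unit] = State { c =>
  if (c.count == Int.MaxValue) (c, Failure(new ArithmeticException("Attempt to overflow counter failed")))
  else (c.copy(count = c.count + 1), Success(()))
}

val decrement: Transition[Unit] = State { c =>
  if (c.count == 0) (c, Failure(new ArithmeticException("Attempt to make count negative failed")))
  else (c.copy(count = c.count - 1), Success(()))
}

// Run the transitions in order, stopping at the first Failure.
// A NonEmptyList always has a head, so there is no "empty sequence" case to report.
def untilFailure[M](ts: NonEmptyList[Transition[M]]): Transition[M] =
  ts.tail.foldLeft(ts.head) { (acc, next) =>
    acc.flatMap { (result: Try[M]) =>
      result match {
        // Previous step succeeded: carry on with the next transition.
        case Success(_) => next
        // Previous step failed: pass the failure through and leave the state alone.
        case Failure(e) => State.pure[Counter, Try[M]](Failure(e))
      }
    }
  }

val counterManip: Transition[Unit] =
  untilFailure(NonEmptyList.of(decrement, increment, increment, increment))
```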
The downside is that all of the transitions need to have a return value in common (unless you're OK with an Any result).
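A technique neither post uses, the EitherT monad transformer from Cats stacked over State, addresses both the wish for a plain for expression that stops at the first failure and the common-return-type downside. EitherT[F, E, A] wraps F[Either[E, A]], and its flatMap runs the next step only on a Right; a Left is passed on through F.pure, so the state reached at the failure is kept. The sketch below assumes cats-core and swaps Try for Either[ArithmeticException, A]; CounterState, transition and currentCount are illustrative names, not library APIs.

```scala
import cats.data.{EitherT, State}

final case class Counter(count: Int)

// State fixed to Counter, as a one-parameter type constructor for EitherT.
type CounterState[A] = State[Counter, A]

// A transition either fails with an ArithmeticException or produces an A.
type Transition[A] = EitherT[CounterState, ArithmeticException, A]

// Illustrative helper: lift a state function returning Either into a Transition.
def transition(f: Counter => (Counter, Either[ArithmeticException, Unit])): Transition[Unit] =
  EitherT[CounterState, ArithmeticException, Unit](State(f))

val increment: Transition[Unit] = transition { c =>
  if (c.count == Int.MaxValue) (c, Left(new ArithmeticException("Attempt to overflow counter failed")))
  else (c.copy(count = c.count + 1), Right(()))
}

val decrement: Transition[Unit] = transition { c =>
  if (c.count == 0) (c, Left(new ArithmeticException("Attempt to make count negative failed")))
  else (c.copy(count = c.count - 1), Right(()))
}

// A step with a different result type: read the current count without changing it.
val currentCount: Transition[Int] =
  EitherT[CounterState, ArithmeticException, Int](
    State.inspect[Counter, Either[ArithmeticException, Int]](c => Right(c.count)))

// Each step runs only if every earlier step produced a Right.
val counterManip: Transition[Int] = for {
  _ <- decrement
  _ <- increment
  _ <- increment
  _ <- increment
  n <- currentCount
} yield n
```

Here counterManip.value.run(Counter(0)).value pairs Counter(0) with a Left holding the "Attempt to make count negative failed" exception, while starting from Counter(1) it gives (Counter(3), Right(3)). Because each step carries its own result type, the transitions no longer need a common return value; the cost is some type plumbing (the CounterState alias and the explicit type arguments).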
From what I understand, your computation has two states, which you can define as an ADT:
sealed trait CompState[A]
case class Ok[A](value: A) extends CompState[A]
case class Err[A](lastValue: A, cause: Exception) extends CompState[A]
The next step you can take is to define an update method for CompState, to encapsulate your logic of what should happen when chaining the computations.
// Member of CompState: this goes inside the body of sealed trait CompState[A] { ... }
def update(f: A => A): CompState[A] = this match {
  case Ok(a) =>
    try Ok(f(a))
    catch { case e: Exception => Err(a, e) }
  case Err(a, e) => Err(a, e)
}
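Because update matches on this, it has to be a member of CompState itself. Assembled that way, a minimal Cats-free sketch shows the chaining semantics: once a step throws, later updates are skipped and the last good value is kept. The Int values and the "too big" exception are made up for illustration, and marking the case classes final is a choice added here.

```scala
sealed trait CompState[A] {
  // Apply f while the computation is still Ok; if f throws, record the
  // last good value together with the exception.
  def update(f: A => A): CompState[A] = this match {
    case Ok(a) =>
      try Ok(f(a))
      catch { case e: Exception => Err(a, e) }
    // Already failed: leave the error (and its last value) unchanged.
    case Err(a, e) => Err(a, e)
  }
}
final case class Ok[A](value: A) extends CompState[A]
final case class Err[A](lastValue: A, cause: Exception) extends CompState[A]

// Hypothetical example: 1 -> 2, then a step that fails, then a step that is skipped.
val result: CompState[Int] = (Ok(1): CompState[Int])
  .update(_ + 1)
  .update(n => if (n > 1) throw new ArithmeticException("too big") else n)
  .update(_ + 100)
```

The result is Err(2, ...) carrying the ArithmeticException: the value 2 from the last successful step survives, and the final + 100 is never applied.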
From there, redefine the transition type and the two operations:
type Transition[M] = State[CompState[Counter], M]
// Operation to increment a counter.
// note: using `State.modify` instead of `.apply`
val increment: Transition[Unit] = State.modify { cs =>
  // use the new `update` method to take advantage of your chaining semantics
  cs.update { c =>
    // If the count is at its maximum, incrementing it must fail.
    if (c.count == Int.MaxValue) {
      throw new ArithmeticException("Attempt to overflow counter failed")
    }
    // Otherwise, increment the count.
    else c.copy(count = c.count + 1)
  }
}
// Operation to decrement a counter.
val decrement: Transition[Unit] = State.modify { cs =>
  cs.update { c =>
    // If the count is zero, decrementing it must fail.
    if (c.count == 0) {
      throw new ArithmeticException("Attempt to make count negative failed")
    }
    // Otherwise, decrement the count.
    else c.copy(count = c.count - 1)
  }
}
Note that in the updated increment/decrement transitions above, I used State.modify, which changes the state but does not generate a result. It looks like the "idiomatic" way to obtain the current state at the end of your transitions is to use State.get, i.e.
val counterManip: State[CompState[Counter], CompState[Counter]] = for {
  _ <- decrement
  _ <- increment
  _ <- increment
  _ <- increment
  // The explicit type argument lets the compiler infer the state type here.
  r <- State.get[CompState[Counter]]
} yield r
And you can run this and discard the final state using the runA helper, i.e.
counterManip.runA(Ok(Counter(0))).value
// Err(Counter(0),java.lang.ArithmeticException: Attempt to make count negative failed)