 

Cats: Non tail recursive tailRecM method for Monads

In cats, when a Monad is created by extending the Monad trait, an implementation of the tailRecM method must be provided.

Below is a scenario for which I found it impossible to provide a tail-recursive implementation of tailRecM:

  import cats.Monad

  sealed trait Tree[+A]

  final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

  final case class Leaf[A](value: A) extends Tree[A]

  implicit val treeMonad: Monad[Tree] = new Monad[Tree] {

    override def pure[A](value: A): Tree[A] = Leaf(value)

    override def flatMap[A, B](initial: Tree[A])(func: A => Tree[B]): Tree[B] =
      initial match {
        case Branch(l, r) => Branch(flatMap(l)(func), flatMap(r)(func))
        case Leaf(value) => func(value)
      }

    //@tailrec
    override def tailRecM[A, B](a: A)(func: (A) => Tree[Either[A, B]]): Tree[B] = {
      func(a) match {
        case Branch(l, r) =>
          Branch(
            flatMap(l) {
              case Right(l) => pure(l)
              case Left(l) => tailRecM(l)(func)
            },
            flatMap(r){
              case Right(r) => pure(r)
              case Left(r) => tailRecM(r)(func)
            }
          )

        case Leaf(Left(value)) => tailRecM(value)(func)

        case Leaf(Right(value)) => Leaf(value)
      }
    }
  }

1) Given the example above, how can this tailRecM method be used to optimize flatMap calls? Is the implementation of flatMap overridden or modified by tailRecM at compile time?

2) If tailRecM is not tail recursive, as above, will it still be more efficient than using the original flatMap method?

Please share your thoughts.

tharindu_DG asked Jun 12 '17 16:06


2 Answers

Sometimes a call stack can be replaced with an explicit list.

Here toVisit keeps track of branches that are waiting to be processed.

toCollect holds the branches that are waiting to be merged until the corresponding branch has finished processing.

import scala.annotation.tailrec

override def tailRecM[A, B](a: A)(f: (A) => Tree[Either[A, B]]): Tree[B] = {
  @tailrec
  def go(toVisit: List[Tree[Either[A, B]]],
         toCollect: List[Tree[B]]): List[Tree[B]] = toVisit match {
    case (tree :: tail) =>
      tree match {
        case Branch(l, r) =>
          l match {
            case Branch(_, _) => go(l :: r :: tail, toCollect)
            case Leaf(Left(a)) => go(f(a) :: r :: tail, toCollect)
            case Leaf(Right(b)) => go(r :: tail, pure(b) +: toCollect)
          }
        case Leaf(Left(a)) => go(f(a) :: tail, toCollect)
        case Leaf(Right(b)) =>
          go(tail,
             if (toCollect.isEmpty) pure(b) +: toCollect
             else Branch(toCollect.head, pure(b)) :: toCollect.tail)
      }
    case Nil => toCollect
  }

  go(f(a) :: Nil, Nil).head
}
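The explicit-list idea can be seen in isolation on a simpler task. The following is a minimal standalone sketch (no cats dependency; the object name ExplicitStackSketch and the countLeaves helpers are hypothetical) that counts the leaves of the same Tree shape, once with ordinary recursion and once with a tail-recursive loop over a List of pending subtrees:

```scala
import scala.annotation.tailrec

object ExplicitStackSketch {
  sealed trait Tree[+A]
  final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
  final case class Leaf[A](value: A) extends Tree[A]

  // Ordinary recursion: one JVM stack frame per level of the tree.
  def countLeavesRec[A](t: Tree[A]): Int = t match {
    case Branch(l, r) => countLeavesRec(l) + countLeavesRec(r)
    case Leaf(_)      => 1
  }

  // Explicit stack: subtrees still to visit are kept in a List on the heap,
  // so `go` is tail recursive and the JVM stack depth stays constant.
  def countLeaves[A](t: Tree[A]): Int = {
    @tailrec
    def go(toVisit: List[Tree[A]], acc: Int): Int = toVisit match {
      case Branch(l, r) :: tail => go(l :: r :: tail, acc) // visit left first, keep right for later
      case Leaf(_) :: tail      => go(tail, acc + 1)
      case Nil                  => acc
    }
    go(List(t), 0)
  }

  val small: Tree[Int] = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))

  // A right-nested tree, 100000 levels deep, built with a loop rather than recursion.
  val deep: Tree[Int] =
    (1 to 100000).foldLeft(Leaf(0): Tree[Int])((t, i) => Branch(Leaf(i), t))

  def main(args: Array[String]): Unit = {
    println(countLeavesRec(small)) // 3
    println(countLeaves(small))    // 3
    println(countLeaves(deep))     // 100001
  }
}
```

Only the list of pending subtrees grows with the tree; the JVM stack depth of the loop does not, which is the same trade the go function above makes.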

From a cats ticket on why to use tailRecM:

tailRecM won't blow the stack (like almost every JVM program it may OOM), for any of the Monads in cats.

And then:

Without tailRecM (or recursive flatMap) being safe, libraries like iteratee.io can't safely be written since they require monadic recursion.

Another ticket states that clients of cats.Monad should be aware that some monads don't have a stack-safe tailRecM:

tailRecM can still be used by those that are trying to get stack safety, so long as they understand that certain monads will not be able to give it to them

Nazarii Bardiuk answered Sep 18 '22 12:09


Relation between tailRecM and flatMap

To answer your first question: the following code is part of FlatMapLaws.scala, from cats-laws. It tests the consistency between the flatMap and tailRecM methods.

/**
 * It is possible to implement flatMap from tailRecM and map
 * and it should agree with the flatMap implementation.
 */
def flatMapFromTailRecMConsistency[A, B](fa: F[A], fn: A => F[B]): IsEq[F[B]] = {
  val tailRecMFlatMap = F.tailRecM[Option[A], B](Option.empty[A]) {
    case None => F.map(fa) { a => Left(Some(a)) }
    case Some(a) => F.map(fn(a)) { b => Right(b) }
  }

  F.flatMap(fa)(fn) <-> tailRecMFlatMap
}

This shows how to implement flatMap from tailRecM, and implicitly suggests that the compiler will not do such a thing automatically. It is up to the user of the Monad to decide when it makes sense to use tailRecM instead of flatMap.
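As a concrete, standalone sketch of that law applied to the question's Tree (without cats, so flatMap, map and tailRecM are plain functions; flatMapViaTailRecM is a hypothetical name), the two ways of computing flatMap can be compared directly. The tailRecM used here is a compact equivalent of the question's non-tail-recursive version:

```scala
object FlatMapFromTailRecM {
  sealed trait Tree[+A]
  final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
  final case class Leaf[A](value: A) extends Tree[A]

  // Structural flatMap, as in the question.
  def flatMap[A, B](t: Tree[A])(f: A => Tree[B]): Tree[B] = t match {
    case Branch(l, r) => Branch(flatMap(l)(f), flatMap(r)(f))
    case Leaf(a)      => f(a)
  }

  def map[A, B](t: Tree[A])(f: A => B): Tree[B] = flatMap(t)(a => Leaf(f(a)))

  // Compact equivalent of the question's tailRecM: correct, but not tail recursive.
  def tailRecM[A, B](a: A)(f: A => Tree[Either[A, B]]): Tree[B] =
    flatMap(f(a)) {
      case Left(next) => tailRecM(next)(f) // keep looping from this leaf
      case Right(b)   => Leaf(b)           // finished leaf
    }

  // flatMap rebuilt from tailRecM and map, following flatMapFromTailRecMConsistency.
  def flatMapViaTailRecM[A, B](fa: Tree[A])(fn: A => Tree[B]): Tree[B] =
    tailRecM[Option[A], B](None) {
      case None    => map(fa)(a => Left(Some(a)))
      case Some(a) => map(fn(a))(b => Right(b))
    }

  val tree: Tree[Int] = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
  val fn: Int => Tree[Int] = i => Branch(Leaf(i), Leaf(i * 10))

  def main(args: Array[String]): Unit =
    println(flatMap(tree)(fn) == flatMapViaTailRecM(tree)(fn)) // true
}
```

Nothing here is rewritten by the compiler: flatMapViaTailRecM is only used where it is called explicitly, and because this tailRecM is itself built on flatMap, it is no more stack safe than flatMap.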

This blog post has nice Scala examples explaining when tailRecM is useful. It follows the PureScript article by Phil Freeman, which originally introduced the method.

It explains the downsides of using flatMap for monadic composition:

This characteristic of Scala limits the usefulness of monadic composition where flatMap can call monadic function f, which then can call flatMap etc..

In contrast, with a tailRecM-based implementation:

This guarantees greater safety on the user of FlatMap typeclass, but it would mean that each the implementers of the instances would need to provide a safe tailRecM.
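The difference can be sketched with Option, assuming a hand-written tailRecM rather than the instance cats ships (the object and function names are hypothetical). A countdown written as recursion through flatMap nests a call on every step, while the same loop through a tail-recursive tailRecM runs in constant stack:

```scala
import scala.annotation.tailrec

object MonadicLoopSketch {
  // Hand-written tailRecM for Option (a sketch, not the cats source):
  // Left carries the next loop state, Right carries the final result.
  @tailrec
  def tailRecM[A, B](a: A)(f: A => Option[Either[A, B]]): Option[B] =
    f(a) match {
      case None           => None
      case Some(Left(a1)) => tailRecM(a1)(f) // loop again, no new stack frame
      case Some(Right(b)) => Some(b)         // done
    }

  // Monadic recursion through flatMap: Option.flatMap invokes the function
  // right away, so each step nests another call on the JVM stack.
  def countDownFlatMap(n: Int): Option[Int] =
    if (n == 0) Some(0) else Some(n - 1).flatMap(countDownFlatMap)

  // The same loop expressed with tailRecM: constant stack depth.
  def countDownTailRecM(n: Int): Option[Int] =
    tailRecM[Int, Int](n)(i => if (i == 0) Some(Right(0)) else Some(Left(i - 1)))

  def main(args: Array[String]): Unit = {
    println(countDownTailRecM(1000000)) // Some(0)
    val viaFlatMap =
      try countDownFlatMap(1000000).toString
      catch { case _: StackOverflowError => "StackOverflowError" }
    println(viaFlatMap)
  }
}
```

On a default JVM stack the flatMap version is expected to fail with StackOverflowError for large n; the exact threshold depends on the configured stack size.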

Many of the methods provided by cats rely on monadic composition. So even if you don't call tailRecM directly, implementing it allows for more efficient composition with other monads.

Implementation for Tree

In a different answer, @nazarii-bardiuk provides an implementation of tailRecM that is tail recursive but does not pass the flatMap/tailRecM consistency test mentioned above: the tree structure is not properly rebuilt after recursion. A fixed version is below:

import scala.annotation.tailrec

def tailRecM[A, B](arg: A)(func: A => Tree[Either[A, B]]): Tree[B] = {
  @tailrec
  def loop(toVisit: List[Tree[Either[A, B]]], 
           toCollect: List[Option[Tree[B]]]): List[Tree[B]] =
    toVisit match {
      case Branch(l, r) :: next =>
        loop(l :: r :: next, None :: toCollect)

      case Leaf(Left(value)) :: next =>
        loop(func(value) :: next, toCollect)

      case Leaf(Right(value)) :: next =>
        loop(next, Some(pure(value)) :: toCollect)

      case Nil =>
        toCollect.foldLeft(Nil: List[Tree[B]]) { (acc, maybeTree) =>
          maybeTree.map(_ :: acc).getOrElse {
            val left :: right :: tail = acc
            Branch(left, right) :: tail
          }
        }
    }

  loop(List(func(arg)), Nil).head
}

(gist with test)
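To check the claim about the first answer, here is a standalone sketch (no cats; names such as tailRecMLoop, tailRecMFirst, tailRecMRec, step and chain are hypothetical) that runs the fixed version, the first answer's version and a plain recursive reference on a function producing a right-nested tree. The fixed version is reproduced with Branch in place of branch and a pattern match in place of the refutable val:

```scala
import scala.annotation.tailrec

object TreeTailRecMCheck {
  sealed trait Tree[+A]
  final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
  final case class Leaf[A](value: A) extends Tree[A]

  def pure[A](a: A): Tree[A] = Leaf(a)

  // Reference: plain recursive tailRecM (correct, but not stack safe).
  def tailRecMRec[A, B](a: A)(f: A => Tree[Either[A, B]]): Tree[B] = {
    def rebuild(t: Tree[Either[A, B]]): Tree[B] = t match {
      case Branch(l, r)   => Branch(rebuild(l), rebuild(r))
      case Leaf(Left(a1)) => tailRecMRec(a1)(f)
      case Leaf(Right(b)) => pure(b)
    }
    rebuild(f(a))
  }

  // Fixed version from this answer (Branch instead of branch, match instead of val pattern).
  def tailRecMLoop[A, B](arg: A)(func: A => Tree[Either[A, B]]): Tree[B] = {
    @tailrec
    def loop(toVisit: List[Tree[Either[A, B]]], toCollect: List[Option[Tree[B]]]): List[Tree[B]] =
      toVisit match {
        case Branch(l, r) :: next       => loop(l :: r :: next, None :: toCollect) // None marks a pending Branch
        case Leaf(Left(value)) :: next  => loop(func(value) :: next, toCollect)
        case Leaf(Right(value)) :: next => loop(next, Some(pure(value)) :: toCollect)
        case Nil =>
          // Replay in reverse visiting order; each None joins the two most recent subtrees.
          toCollect.foldLeft(Nil: List[Tree[B]]) { (acc, maybeTree) =>
            maybeTree.map(_ :: acc).getOrElse {
              acc match {
                case left :: right :: tail => Branch(left, right) :: tail
                case _                     => sys.error("unbalanced tree markers")
              }
            }
          }
      }
    loop(List(func(arg)), Nil).head
  }

  // First answer's version, unchanged apart from being a standalone function.
  def tailRecMFirst[A, B](a: A)(f: A => Tree[Either[A, B]]): Tree[B] = {
    @tailrec
    def go(toVisit: List[Tree[Either[A, B]]], toCollect: List[Tree[B]]): List[Tree[B]] =
      toVisit match {
        case tree :: tail =>
          tree match {
            case Branch(l, r) =>
              l match {
                case Branch(_, _)   => go(l :: r :: tail, toCollect)
                case Leaf(Left(a))  => go(f(a) :: r :: tail, toCollect)
                case Leaf(Right(b)) => go(r :: tail, pure(b) +: toCollect)
              }
            case Leaf(Left(a))  => go(f(a) :: tail, toCollect)
            case Leaf(Right(b)) =>
              go(tail,
                 if (toCollect.isEmpty) pure(b) +: toCollect
                 else Branch(toCollect.head, pure(b)) :: toCollect.tail)
          }
        case Nil => toCollect
      }
    go(f(a) :: Nil, Nil).head
  }

  // Each step emits a finished leaf on the left and continues on the right.
  val step: Int => Tree[Either[Int, Int]] =
    n => if (n < 3) Branch(Leaf(Right(n)), Leaf(Left(n + 1))) else Leaf(Right(n))

  // A long chain of Left steps, to exercise stack safety.
  val chain: Int => Tree[Either[Int, Int]] =
    n => if (n < 100000) Leaf(Left(n + 1)) else Leaf(Right(n))

  def main(args: Array[String]): Unit = {
    println(tailRecMLoop(0)(step))  // Branch(Leaf(0),Branch(Leaf(1),Branch(Leaf(2),Leaf(3))))
    println(tailRecMRec(0)(step))   // Branch(Leaf(0),Branch(Leaf(1),Branch(Leaf(2),Leaf(3))))
    println(tailRecMFirst(0)(step)) // Branch(Leaf(2),Leaf(3))
    println(tailRecMLoop(0)(chain)) // Leaf(100000)
  }
}
```

The first answer's version drops Leaf(0) and Leaf(1) on this input, while the fixed version agrees with the recursive reference. The chain run also illustrates stack safety: each Left step replaces the head of toVisit instead of adding a frame.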

You may already know this, but your example (as well as the answer by @nazarii-bardiuk) appears in the book Scala with Cats by Noel Welsh and Dave Gurnell (highly recommended).

hjmeijer answered Sep 20 '22 12:09