 

Functional patterns for better chaining of collect

I often find myself needing to chain collects, where I want to do multiple collects in a single traversal. I would also like to return a "remainder" for things that don't match any of the collects.

For example:

sealed trait Animal
case class Cat(name: String) extends Animal
case class Dog(name: String, age: Int) extends Animal

val animals: List[Animal] =
  List(Cat("Bob"), Dog("Spot", 3), Cat("Sally"), Dog("Jim", 11))

// Normal way
val cats: List[Cat]    = animals.collect { case c: Cat => c }
val dogAges: List[Int] = animals.collect { case Dog(_, age) => age }
val rem: List[Animal]  = Nil // No easy way to create this without repeated code

This really isn't great: it requires multiple iterations, and there is no reasonable way to calculate the remainder. I could write a very complicated fold to pull this off, but it would be really nasty.

Instead, I usually opt for mutation, which is fairly similar to the logic you would have in a fold:

import scala.collection.mutable.ListBuffer

// Ugly, hide the mutation away
val (cats2, dogsAges2, rem2) = {
  // Lose some benefits of type inference
  val cs = ListBuffer[Cat]()
  val da = ListBuffer[Int]()
  val rem = ListBuffer[Animal]()
  // Bad separation of concerns, I have to merge all of my functions
  animals.foreach {
    case c: Cat      => cs += c
    case Dog(_, age) => da += age
    case other       => rem += other
  }
  (cs.toList, da.toList, rem.toList)
}

I don't like this one bit: it has worse type inference and worse separation of concerns, since I have to merge all of the various partial functions into one. It also requires a lot of code.

What I want are some useful patterns, like a collect that returns the remainder (I grant that partitionMap, new in 2.13, does this, but less elegantly). I could also use some form of pipe or map for operating on parts of tuples. Here are some made-up utilities:

implicit class ListSyntax[A](xs: List[A]) {
  import scala.collection.mutable.ListBuffer
  // Collect and return remainder
  // A specialized form of new 2.13 partitionMap
  def collectR[B](pf: PartialFunction[A, B]): (List[B], List[A]) = {
    val rem = new ListBuffer[A]()
    val res = new ListBuffer[B]()
    val f = pf.lift
    for (elt <- xs) {
      f(elt) match {
        case Some(r) => res += r
        case None    => rem += elt
      }
    }
    (res.toList, rem.toList)
  }
}
implicit class Tuple2Syntax[A, B](x: Tuple2[A, B]){
  def chainR[C](f: B => C): Tuple2[A, C] = x.copy(_2 = f(x._2))
}

Now I can write this in a way that could be done in a single traversal (with a lazy data structure) and that still follows functional, immutable practice:

// Relatively pretty, can imagine lazy forms using a single iteration
val (cats3, (dogAges3, rem3)) =
  animals.collectR          { case c: Cat => c }
         .chainR(_.collectR { case Dog(_, age) => age })
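For comparison, the made-up collectR can be written as a thin wrapper over 2.13's partitionMap: lift the partial function to an Option, then turn that Option into an Either with toLeft, so matches go left and unmatched elements go right. This is a sketch assuming Scala 2.13; the object and implicit class names (PartitionMapSketch, ListCollectROps, Tuple2MapSecond) are hypothetical, not part of any library.

```scala
object PartitionMapSketch {
  sealed trait Animal
  case class Cat(name: String) extends Animal
  case class Dog(name: String, age: Int) extends Animal

  val animals: List[Animal] =
    List(Cat("Bob"), Dog("Spot", 3), Cat("Sally"), Dog("Jim", 11))

  // Hypothetical syntax class: collect with remainder via partitionMap.
  implicit class ListCollectROps[A](private val xs: List[A]) extends AnyVal {
    def collectR[B](pf: PartialFunction[A, B]): (List[B], List[A]) = {
      val f = pf.lift
      // Some(b).toLeft(a) == Left(b); None.toLeft(a) == Right(a)
      xs.partitionMap(a => f(a).toLeft(a))
    }
  }

  // Hypothetical stand-in for the question's chainR: map only the second element.
  implicit class Tuple2MapSecond[A, B](private val t: (A, B)) extends AnyVal {
    def chainR[C](f: B => C): (A, C) = (t._1, f(t._2))
  }

  val (cats, (dogAges, rem)) =
    animals.collectR { case c: Cat => c }
           .chainR(_.collectR { case Dog(_, age) => age })
}
```

Each collectR still makes its own pass, but the second pass only sees the remainder left by the first.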

My question is, are there patterns like this? It smells like the type of thing that would be in a library like Cats, FS2, or ZIO, but I am not sure what it might be called.

Scastie link of code examples: https://scastie.scala-lang.org/Egz78fnGR6KyqlUTNTv9DQ

Asked Aug 06 '20 00:08 by Jack Koenig


3 Answers

I wanted to see just how "nasty" a fold() would be.

val (cats
    ,dogAges
    ,rem) = animals.foldRight((List.empty[Cat]
                              ,List.empty[Int]
                              ,List.empty[Animal])) {
  case (c:Cat,   (cs,ds,rs)) => (c::cs, ds, rs)
  case (Dog(_,d),(cs,ds,rs)) => (cs, d::ds, rs)
  case (r,       (cs,ds,rs)) => (cs, ds, r::rs)
}

Eye of the beholder I suppose.
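If the objection to a single fold is that every case gets merged into one function, a small generic helper can keep the partial functions separate while still making one pass. This is an illustrative sketch assuming Scala 2.13, not part of the answer above: collect2 and Collect2Sketch are hypothetical names, and the helper is fixed at exactly two partial functions.

```scala
object Collect2Sketch {
  sealed trait Animal
  case class Cat(name: String) extends Animal
  case class Dog(name: String, age: Int) extends Animal

  // Hypothetical helper: a single foldRight pass. pf1 is tried first, then pf2;
  // anything matched by neither goes to the remainder. Folding from the right
  // and prepending keeps every output list in input order.
  def collect2[A, B, C](xs: List[A])(
      pf1: PartialFunction[A, B],
      pf2: PartialFunction[A, C]
  ): (List[B], List[C], List[A]) = {
    val f1 = pf1.lift
    val f2 = pf2.lift
    xs.foldRight((List.empty[B], List.empty[C], List.empty[A])) {
      case (a, (bs, cs, rs)) =>
        f1(a) match {
          case Some(b) => (b :: bs, cs, rs)
          case None =>
            f2(a) match {
              case Some(c) => (bs, c :: cs, rs)
              case None    => (bs, cs, a :: rs)
            }
        }
    }
  }

  val animals: List[Animal] =
    List(Cat("Bob"), Dog("Spot", 3), Cat("Sally"), Dog("Jim", 11))

  // Each concern keeps its own partial function.
  val (cats, dogAges, rem) =
    collect2(animals)({ case c: Cat => c }, { case Dog(_, age) => age })
}
```

Supporting an arbitrary number of partial functions this way would need overloads or an HList-style encoding, which is where the other answers head.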

Answered Oct 13 '22 10:10 by jwvh


How about defining a couple of utility classes to help you with this?

case class ListCollect[A](list: List[A]) {
  def partialCollect[B](f: PartialFunction[A, B]): ChainCollect[List[B], A] = {
    val (cs, rem) = list.partition(f.isDefinedAt)
    new ChainCollect((cs.map(f), rem))
  }
}

case class ChainCollect[A, B](tuple: (A, List[B])) {
  def partialCollect[C](f: PartialFunction[B, C]): ChainCollect[(A, List[C]), B] = {
    val (cs, rem) = tuple._2.partition(f.isDefinedAt)
    ChainCollect(((tuple._1, cs.map(f)), rem))
  }
}

ListCollect is just meant to start the chain. ChainCollect takes the previous remainder (the second element of the tuple), tries to apply a PartialFunction to it, and creates a new ChainCollect object. I'm not particularly fond of the nested tuples this produces, but you may be able to make it look a bit better if you use Shapeless's HLists.

val ((cats, dogs), rem) = ListCollect(animals)
  .partialCollect { case c: Cat => c }
  .partialCollect { case Dog(_, age) => age }
  .tuple



Dotty's (Scala 3's) *: tuple type makes this a bit easier:

opaque type ChainResult[Prev <: Tuple, Rem] = (Prev, List[Rem])

extension [P <: Tuple, R, N](chainRes: ChainResult[P, R]) {
  def partialCollect(f: PartialFunction[R, N]): ChainResult[List[N] *: P, R] = {
    val (cs, rem) = chainRes._2.partition(f.isDefinedAt)
    (cs.map(f) *: chainRes._1, rem)
  }
}

This does result in the output being reversed, but it avoids the ugly nesting of my previous approach:


val ((owls, dogs, cats), rem) = (EmptyTuple, animals)
  .partialCollect { case c: Cat => c }
  .partialCollect { case Dog(_, age) => age }
  .partialCollect { case Owl(wisdom) => wisdom }

/* more animals */

case class Owl(wisdom: Double) extends Animal
case class Fly(isAnimal: Boolean) extends Animal

val animals: List[Animal] =
  List(Cat("Bob"), Dog("Spot", 3), Cat("Sally"), Dog("Jim", 11), Owl(200), Fly(false))


And if you still don't like that, you can always define a few more helper methods to reverse the tuple, add the extension on a List without requiring an EmptyTuple to begin with, etc.

//Add this to the ChainResult extension
def end: Reverse[List[R] *: P] = {
    def revHelp[A <: Tuple, R <: Tuple](acc: A, rest: R): RevHelp[A, R] =
      rest match {
        case EmptyTuple => acc.asInstanceOf[RevHelp[A, R]]
        case h *: t => revHelp(h *: acc, t).asInstanceOf[RevHelp[A, R]]
      }
    revHelp(EmptyTuple, chainRes._2 *: chainRes._1)
  }

//Helpful types for safety
type Reverse[T <: Tuple] = RevHelp[EmptyTuple, T]
type RevHelp[A <: Tuple, R <: Tuple] <: Tuple = R match {
  case EmptyTuple => A
  case h *: t => RevHelp[h *: A, t]
}

And now you can do this:

val (cats, dogs, owls, rem) = (EmptyTuple, animals)
  .partialCollect { case c: Cat => c }
  .partialCollect { case Dog(_, age) => age }
  .partialCollect { case Owl(wisdom) => wisdom }
  .end
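The answer mentions putting the extension directly on a List so the chain does not have to start from (EmptyTuple, animals). A minimal Scala 3 sketch of that, under a few assumptions: chain and ChainSketch are hypothetical names; ChainResult is a plain (transparent) type alias rather than the answer's opaque type, so the sketch does not depend on opaque-type scoping; and partialCollect uses partitionMap so each partial function is evaluated once per element instead of via isDefinedAt followed by apply.

```scala
object ChainSketch {
  sealed trait Animal
  case class Cat(name: String) extends Animal
  case class Dog(name: String, age: Int) extends Animal
  case class Owl(wisdom: Double) extends Animal

  val animals: List[Animal] =
    List(Cat("Bob"), Dog("Spot", 3), Cat("Sally"), Owl(200.0))

  // Transparent alias (assumption): earlier results plus the current remainder.
  type ChainResult[Prev <: Tuple, Rem] = (Prev, List[Rem])

  // Hypothetical starter so the chain can begin from a List directly.
  extension [A](xs: List[A]) {
    def chain: ChainResult[EmptyTuple, A] = (EmptyTuple, xs)
  }

  extension [P <: Tuple, R](chainRes: ChainResult[P, R]) {
    // Prepend this step's matches to the earlier results; the unmatched
    // elements become the remainder for the next step.
    def partialCollect[N](f: PartialFunction[R, N]): ChainResult[List[N] *: P, R] = {
      val lifted = f.lift
      val (cs, rem) = chainRes._2.partitionMap(r => lifted(r).toLeft(r))
      (cs *: chainRes._1, rem)
    }
  }

  // Results come out most-recent-first, as in the answer.
  val ((dogAges, cats), rem) =
    animals.chain
      .partialCollect { case c: Cat => c }
      .partialCollect { case Dog(_, age) => age }
}
```

The trade-off of the transparent alias is that partialCollect becomes available on any pair of a tuple and a List; the opaque type in the answer prevents that. The ordering issue is unchanged, which is what the end helper above addresses.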


Answered Oct 13 '22 10:10 by user


Since you mentioned Cats, I would also add a solution using foldMap:

// foldMap syntax and the Map/Tuple3 Monoid instances come from Cats
import cats.implicits._

sealed trait Animal
case class Cat(name: String) extends Animal
case class Dog(name: String) extends Animal
case class Snake(name: String) extends Animal

val animals: List[Animal] = List(Cat("Bob"), Dog("Spot"), Cat("Sally"), Dog("Jim"), Snake("Billy"))

val map = animals.foldMap{ //Map(other -> List(Snake(Billy)), cats -> List(Cat(Bob), Cat(Sally)), dogs -> List(Dog(Spot), Dog(Jim)))
  case d: Dog => Map("dogs" -> List(d))
  case c: Cat => Map("cats" -> List(c))
  case o => Map("other" -> List(o))
}

val tuples = animals.foldMap{ //(List(Dog(Spot), Dog(Jim)),List(Cat(Bob), Cat(Sally)),List(Snake(Billy)))
  case d: Dog => (List(d), Nil, Nil)
  case c: Cat => (Nil, List(c), Nil)
  case o => (Nil, Nil, List(o))
}

Arguably it's more succinct than the fold version, but it has to combine partial results using monoids, so it won't be as performant.

Answered Oct 13 '22 09:10 by Krzysztof Atłasik