
Generic derivation for ADTs in Scala with a custom representation

I'm paraphrasing a question from the circe Gitter channel here.

Suppose I've got a Scala sealed trait hierarchy (or ADT) like this:

sealed trait Item
case class Cake(flavor: String, height: Int) extends Item
case class Hat(shape: String, material: String, color: String) extends Item

…and I want to be able to map back and forth between this ADT and a JSON representation like the following:

{ "tag": "Cake", "contents": ["cherry", 100] }
{ "tag": "Hat", "contents": ["cowboy", "felt", "brown"] }

By default circe's generic derivation uses a different representation:

scala> val item1: Item = Cake("cherry", 100)
item1: Item = Cake(cherry,100)

scala> val item2: Item = Hat("cowboy", "felt", "brown")
item2: Item = Hat(cowboy,felt,brown)

scala> import io.circe.generic.auto._, io.circe.syntax._
import io.circe.generic.auto._
import io.circe.syntax._

scala> item1.asJson.noSpaces
res0: String = {"Cake":{"flavor":"cherry","height":100}}

scala> item2.asJson.noSpaces
res1: String = {"Hat":{"shape":"cowboy","material":"felt","color":"brown"}}

We can get a little closer with circe-generic-extras:

import io.circe.generic.extras.Configuration
import io.circe.generic.extras.auto._

implicit val configuration: Configuration =
   Configuration.default.withDiscriminator("tag")

And then:

scala> item1.asJson.noSpaces
res2: String = {"flavor":"cherry","height":100,"tag":"Cake"}

scala> item2.asJson.noSpaces
res3: String = {"shape":"cowboy","material":"felt","color":"brown","tag":"Hat"}

…but it's still not what we want.

What's the best way to use circe to derive instances like this generically for ADTs in Scala?

asked Aug 31 '18 14:08 by Travis Brown

1 Answer

Representing case classes as JSON arrays

The first thing to note is that the circe-shapes module provides instances for Shapeless's HLists that use an array representation like the one we want for our case classes. For example:

scala> import io.circe.shapes._
import io.circe.shapes._

scala> import shapeless._
import shapeless._

scala> ("foo" :: 1 :: List(true, false) :: HNil).asJson.noSpaces
res4: String = ["foo",1,[true,false]]

…and Shapeless itself provides a generic mapping between case classes and HLists. We can combine these two to get the generic instances we want for case classes:

import io.circe.{ Decoder, Encoder }
import io.circe.shapes.HListInstances
import shapeless.{ Generic, HList }

trait FlatCaseClassCodecs extends HListInstances {
  implicit def encodeCaseClassFlat[A, Repr <: HList](implicit
    gen: Generic.Aux[A, Repr],
    encodeRepr: Encoder[Repr]
  ): Encoder[A] = encodeRepr.contramap(gen.to)

  implicit def decodeCaseClassFlat[A, Repr <: HList](implicit
    gen: Generic.Aux[A, Repr],
    decodeRepr: Decoder[Repr]
  ): Decoder[A] = decodeRepr.map(gen.from)
}

object FlatCaseClassCodecs extends FlatCaseClassCodecs

And then:

scala> import FlatCaseClassCodecs._
import FlatCaseClassCodecs._

scala> Cake("cherry", 100).asJson.noSpaces
res5: String = ["cherry",100]

scala> Hat("cowboy", "felt", "brown").asJson.noSpaces
res6: String = ["cowboy","felt","brown"]

Note that I'm using io.circe.shapes.HListInstances to bundle up just the instances we need from circe-shapes together with our custom case class instances, in order to minimize the number of things our users have to import (both as a matter of ergonomics and for the sake of keeping down compile times).
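
The session above only exercises the encoding side. Here's a minimal sketch of a round trip through the same instances. It repeats the trait so that the snippet compiles on its own, and it assumes circe-core, circe-shapes, and Shapeless are on the classpath. The FlatRoundTrip object and the test values are only for illustration.

```scala
import io.circe.{ Decoder, Encoder, Json }
import io.circe.shapes.HListInstances
import io.circe.syntax._
import shapeless.{ Generic, HList }

case class Cake(flavor: String, height: Int)

// Same definitions as above, repeated so that this sketch stands alone.
trait FlatCaseClassCodecs extends HListInstances {
  implicit def encodeCaseClassFlat[A, Repr <: HList](implicit
    gen: Generic.Aux[A, Repr],
    encodeRepr: Encoder[Repr]
  ): Encoder[A] = encodeRepr.contramap(gen.to)

  implicit def decodeCaseClassFlat[A, Repr <: HList](implicit
    gen: Generic.Aux[A, Repr],
    decodeRepr: Decoder[Repr]
  ): Decoder[A] = decodeRepr.map(gen.from)
}

object FlatCaseClassCodecs extends FlatCaseClassCodecs

object FlatRoundTrip extends App {
  import FlatCaseClassCodecs._

  val cakeJson = Json.arr(Json.fromString("cherry"), Json.fromInt(100))

  // Encoding produces the array, and decoding the array gives the case class back.
  assert(Cake("cherry", 100).asJson == cakeJson)
  assert(cakeJson.as[Cake] == Right(Cake("cherry", 100)))

  // An element of the wrong type in any position makes the whole decode fail.
  val badJson = Json.arr(Json.fromString("cherry"), Json.fromString("tall"))
  assert(badJson.as[Cake].isLeft)
}
```

The fields are matched purely by position, so reordering the members of a case class changes its JSON representation.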

Encoding the generic representation of our ADTs

That's a good first step, but it doesn't get us the representation we want for Item itself. To do that we need some more complex machinery:

import io.circe.{ JsonObject, ObjectEncoder }
import shapeless.{ :+:, CNil, Coproduct, Inl, Inr, Witness }
import shapeless.labelled.FieldType

trait ReprEncoder[C <: Coproduct] extends ObjectEncoder[C]

object ReprEncoder {
  def wrap[A <: Coproduct](encodeA: ObjectEncoder[A]): ReprEncoder[A] =
    new ReprEncoder[A] {
      def encodeObject(a: A): JsonObject = encodeA.encodeObject(a)
    }

  implicit val encodeCNil: ReprEncoder[CNil] = wrap(
    ObjectEncoder.instance[CNil](_ => sys.error("Cannot encode CNil"))
  )

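  // Inductive step: encode the head alternative if we have it, otherwise recurse into the tail.
  implicit def encodeCCons[K <: Symbol, L, R <: Coproduct](implicit
    witK: Witness.Aux[K],   // gives us the constructor name as a value
    encodeL: Encoder[L],    // encoder for the head case class
    encodeR: ReprEncoder[R] // encoder for the remaining alternatives
  ): ReprEncoder[FieldType[K, L] :+: R] = wrap[FieldType[K, L] :+: R](
    ObjectEncoder.instance {
      case Inl(l) => JsonObject("tag" := witK.value.name, "contents" := (l: L))
      case Inr(r) => encodeR.encodeObject(r)
    }
  )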
}

This tells us how to encode instances of Coproduct, which Shapeless uses as a generic representation of sealed trait hierarchies in Scala. The code may be intimidating at first, but it's a very common pattern, and if you spend much time working with Shapeless you'll recognize that 90% of this code is essentially boilerplate that you see any time you build up instances inductively like this.
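
To make the Coproduct representation a little more concrete, here's a small sketch that only needs Shapeless on the classpath. It shows what LabelledGeneric produces for the Item hierarchy from the question. The CoproductDemo object is just a name I've made up for the illustration.

```scala
import shapeless.LabelledGeneric

sealed trait Item
case class Cake(flavor: String, height: Int) extends Item
case class Hat(shape: String, material: String, color: String) extends Item

object CoproductDemo extends App {
  // The representation is a coproduct with one labelled alternative per
  // case class: Cake, then Hat, then the terminating CNil.
  val gen = LabelledGeneric[Item]

  val cake: Item = Cake("cherry", 100)
  val hat: Item = Hat("cowboy", "felt", "brown")

  // The first alternative is wrapped in Inl; later ones get one Inr per
  // alternative that was skipped.
  println(gen.to(cake)) // Inl(Cake(cherry,100))
  println(gen.to(hat))  // Inr(Inl(Hat(cowboy,felt,brown)))

  // gen.from goes in the other direction, which is what a decoder needs.
  assert(gen.from(gen.to(hat)) == hat)
}
```

The Inl and Inr cases in encodeCCons are matching on exactly these wrappers: Inl means "the value is this alternative", and Inr means "look further down the list".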

Decoding these coproducts

The decoding implementation is a little more involved, but it follows the same pattern:

import io.circe.{ DecodingFailure, HCursor }
import shapeless.labelled.field

trait ReprDecoder[C <: Coproduct] extends Decoder[C]

object ReprDecoder {
  def wrap[A <: Coproduct](decodeA: Decoder[A]): ReprDecoder[A] =
    new ReprDecoder[A] {
      def apply(c: HCursor): Decoder.Result[A] = decodeA(c)
    }

  implicit val decodeCNil: ReprDecoder[CNil] = wrap(
    Decoder.failed(DecodingFailure("CNil", Nil))
  )

  implicit def decodeCCons[K <: Symbol, L, R <: Coproduct](implicit
    witK: Witness.Aux[K],
    decodeL: Decoder[L],
    decodeR: ReprDecoder[R]
  ): ReprDecoder[FieldType[K, L] :+: R] = wrap(
    // Decode "contents" as L, but only if "tag" matches this constructor's name.
    decodeL.prepare(_.downField("contents")).validate(
      _.downField("tag").focus
        .flatMap(_.as[String].right.toOption)
        .contains(witK.value.name),
      witK.value.name
    )
    .map(l => Inl[FieldType[K, L], R](field[K](l)))
    // Otherwise fall through to the decoder for the remaining alternatives.
    .or(decodeR.map[FieldType[K, L] :+: R](Inr(_)))
  )
}

In general there will be a little more logic involved in our Decoder implementations, since each decoding step can fail.

Our ADT representation

Now we can put it all together:

import shapeless.{ LabelledGeneric, Lazy }

object Derivation extends FlatCaseClassCodecs {
  implicit def encodeAdt[A, Repr <: Coproduct](implicit
    gen: LabelledGeneric.Aux[A, Repr],
    encodeRepr: Lazy[ReprEncoder[Repr]]
  ): ObjectEncoder[A] = encodeRepr.value.contramapObject(gen.to)

  implicit def decodeAdt[A, Repr <: Coproduct](implicit
    gen: LabelledGeneric.Aux[A, Repr],
    decodeRepr: Lazy[ReprDecoder[Repr]]
  ): Decoder[A] = decodeRepr.value.map(gen.from)
}

This looks very similar to the definitions in our FlatCaseClassCodecs above, and the idea is the same: we're defining instances for our data type (either case classes or ADTs) by building on the instances for the generic representations of those data types. Note that I'm extending FlatCaseClassCodecs, again to minimize imports for the user.

In action

Now we can use these instances like this:

scala> import Derivation._
import Derivation._

scala> item1.asJson.noSpaces
res7: String = {"tag":"Cake","contents":["cherry",100]}

scala> item2.asJson.noSpaces
res8: String = {"tag":"Hat","contents":["cowboy","felt","brown"]}

…which is exactly what we wanted. And the best part is that this will work for any sealed trait hierarchy in Scala, no matter how many case classes it has or how many members those case classes have (although compile times will start to hurt once you're into the dozens of either), assuming all of the member types have JSON representations.
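
The question asked for a mapping in both directions, so here's a sketch of the decoding side as well. It isn't standalone: it continues the session above and assumes that Item, Cake, Hat, item1, and the Derivation object are already in scope. The "Shoe" tag is made up for the illustration.

```scala
import io.circe.Json
import io.circe.syntax._
import Derivation._

val hatJson = Json.obj(
  "tag" -> Json.fromString("Hat"),
  "contents" -> Json.arr(
    Json.fromString("cowboy"),
    Json.fromString("felt"),
    Json.fromString("brown")
  )
)

// Decoding picks the constructor named by "tag" and reads "contents" as its fields.
assert(hatJson.as[Item] == Right(Hat("cowboy", "felt", "brown")))

// Encoding and then decoding gives back the original value.
assert(item1.asJson.as[Item] == Right(item1))

// A tag that matches no constructor falls through to the CNil decoder and fails.
val shoeJson = Json.obj("tag" -> Json.fromString("Shoe"), "contents" -> Json.arr())
assert(shoeJson.as[Item].isLeft)
```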

answered Nov 05 '22 04:11 by Travis Brown