
Generic derivation for ADTs in Scala with a custom representation


I'm paraphrasing a question from the circe Gitter channel here.

Suppose I've got a Scala sealed trait hierarchy (or ADT) like this:

sealed trait Item
case class Cake(flavor: String, height: Int) extends Item
case class Hat(shape: String, material: String, color: String) extends Item

…and I want to be able to map back and forth between this ADT and a JSON representation like the following:

{ "tag": "Cake", "contents": ["cherry", 100] }
{ "tag": "Hat", "contents": ["cowboy", "felt", "brown"] }
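For this one ADT, the encoding half can be written by hand in a few lines. The sketch below assumes only circe core and io.circe.syntax. The `HandWritten` object is a hypothetical name. The sketch relies on the fact that circe encodes tuples as JSON arrays:

```scala
import io.circe.{ Encoder, Json }
import io.circe.syntax._

sealed trait Item
case class Cake(flavor: String, height: Int) extends Item
case class Hat(shape: String, material: String, color: String) extends Item

// Hypothetical hand-written instance for this one ADT: circe core only, no derivation.
object HandWritten {
  implicit val encodeItem: Encoder[Item] = Encoder.instance[Item] {
    // Tuples are encoded by circe as JSON arrays, which gives us "contents".
    case Cake(flavor, height) =>
      Json.obj("tag" -> "Cake".asJson, "contents" -> (flavor, height).asJson)
    case Hat(shape, material, color) =>
      Json.obj("tag" -> "Hat".asJson, "contents" -> (shape, material, color).asJson)
  }
}
```

Writing one of these (plus a matching decoder) for every ADT is the boilerplate I'd like generic derivation to remove.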

By default, circe's generic derivation uses a different representation:

scala> val item1: Item = Cake("cherry", 100)
item1: Item = Cake(cherry,100)

scala> val item2: Item = Hat("cowboy", "felt", "brown")
item2: Item = Hat(cowboy,felt,brown)

scala> import io.circe.generic.auto._, io.circe.syntax._
import io.circe.generic.auto._
import io.circe.syntax._

scala> item1.asJson.noSpaces
res0: String = {"Cake":{"flavor":"cherry","height":100}}

scala> item2.asJson.noSpaces
res1: String = {"Hat":{"shape":"cowboy","material":"felt","color":"brown"}}

We can get a little closer with circe-generic-extras:

import io.circe.generic.extras.Configuration
import io.circe.generic.extras.auto._

implicit val configuration: Configuration =
  Configuration.default.withDiscriminator("tag")

And then:

scala> item1.asJson.noSpaces
res2: String = {"flavor":"cherry","height":100,"tag":"Cake"}

scala> item2.asJson.noSpaces
res3: String = {"shape":"cowboy","material":"felt","color":"brown","tag":"Hat"}

…but it's still not what we want.

What's the best way to use circe to derive instances like this generically for ADTs in Scala?


Solution

    Representing case classes as JSON arrays

    The first thing to note is that the circe-shapes module provides instances for Shapeless's HLists that use an array representation like the one we want for our case classes. For example:

    scala> import io.circe.shapes._
    import io.circe.shapes._
    
    scala> import shapeless._
    import shapeless._
    
    scala> ("foo" :: 1 :: List(true, false) :: HNil).asJson.noSpaces
    res4: String = ["foo",1,[true,false]]
    

    …and Shapeless itself provides a generic mapping between case classes and HLists. We can combine these two to get the generic instances we want for case classes:

    import io.circe.{ Decoder, Encoder }
    import io.circe.shapes.HListInstances
    import shapeless.{ Generic, HList }
    
    trait FlatCaseClassCodecs extends HListInstances {
      implicit def encodeCaseClassFlat[A, Repr <: HList](implicit
        gen: Generic.Aux[A, Repr],
        encodeRepr: Encoder[Repr]
      ): Encoder[A] = encodeRepr.contramap(gen.to)
    
      implicit def decodeCaseClassFlat[A, Repr <: HList](implicit
        gen: Generic.Aux[A, Repr],
        decodeRepr: Decoder[Repr]
      ): Decoder[A] = decodeRepr.map(gen.from)
    }
    
    object FlatCaseClassCodecs extends FlatCaseClassCodecs
    

    And then:

    scala> import FlatCaseClassCodecs._
    import FlatCaseClassCodecs._
    
    scala> Cake("cherry", 100).asJson.noSpaces
    res5: String = ["cherry",100]
    
    scala> Hat("cowboy", "felt", "brown").asJson.noSpaces
    res6: String = ["cowboy","felt","brown"]
    

    Note that I'm using io.circe.shapes.HListInstances to bundle up just the instances we need from circe-shapes together with our custom case class instances, in order to minimize the number of things our users have to import (both as a matter of ergonomics and to keep compile times down).
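    The two ingredients can also be used directly, without the implicit wrappers. The sketch below does this for Cake alone. It assumes circe-shapes and circe-parser are on the classpath and Scala 2.12+ for right-biased Either. The `FlatByHand` name is hypothetical:

```scala
import io.circe.parser.decode
import io.circe.shapes._
import io.circe.syntax._
import shapeless.{ ::, Generic, HNil }

case class Cake(flavor: String, height: Int)

object FlatByHand {
  // Shapeless maps the case class to (and from) its HList representation.
  val gen = Generic[Cake]

  // Step 1: Cake => String :: Int :: HNil.
  val repr: String :: Int :: HNil = gen.to(Cake("cherry", 100))

  // Step 2: circe-shapes encodes the HList as a JSON array.
  val json: String = repr.asJson.noSpaces

  // Reverse direction: decode the array as an HList, then rebuild the Cake.
  val back = decode[String :: Int :: HNil](json).map(gen.from)
}
```

    encodeCaseClassFlat and decodeCaseClassFlat do the same two steps. They work for any case class that Shapeless can find a Generic for.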

    Encoding the generic representation of our ADTs

    That's a good first step, but it doesn't get us the representation we want for Item itself. To do that, we need more complex machinery:

    import io.circe.{ JsonObject, ObjectEncoder }
    import io.circe.syntax._ // provides the `:=` operator used below
    import shapeless.{ :+:, CNil, Coproduct, Inl, Inr, Witness }
    import shapeless.labelled.FieldType
    
    trait ReprEncoder[C <: Coproduct] extends ObjectEncoder[C]
    
    object ReprEncoder {
      def wrap[A <: Coproduct](encodeA: ObjectEncoder[A]): ReprEncoder[A] =
        new ReprEncoder[A] {
          def encodeObject(a: A): JsonObject = encodeA.encodeObject(a)
        }
    
      implicit val encodeCNil: ReprEncoder[CNil] = wrap(
        ObjectEncoder.instance[CNil](_ => sys.error("Cannot encode CNil"))
      )
    
      implicit def encodeCCons[K <: Symbol, L, R <: Coproduct](implicit
        witK: Witness.Aux[K],
        encodeL: Encoder[L],
        encodeR: ReprEncoder[R]
      ): ReprEncoder[FieldType[K, L] :+: R] = wrap[FieldType[K, L] :+: R](
        ObjectEncoder.instance {
          case Inl(l) => JsonObject("tag" := witK.value.name, "contents" := (l: L))
          case Inr(r) => encodeR.encodeObject(r)
        }
      )
    }
    

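    This tells us how to encode instances of Coproduct, which Shapeless uses as a generic representation of sealed trait hierarchies in Scala. The CNil instance is the base case; it can never actually be called, since CNil has no values. encodeCCons is the inductive step: it writes the tag and contents for the first alternative (Inl) and delegates everything else (Inr) to the instance for the rest of the coproduct. The code may be intimidating at first, but it's a very common pattern. If you spend much time working with Shapeless, you'll recognize that 90% of this code is essentially boilerplate that shows up any time you build up instances inductively like this.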
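    To see the inductive pattern without any JSON in the way, here's a sketch of a much smaller type class built the same way: a base case for CNil and an inductive case for FieldType[K, L] :+: R. `TagNames` and `namesOf` are hypothetical names, not part of circe or Shapeless:

```scala
import shapeless.{ :+:, CNil, Coproduct, LabelledGeneric, Witness }
import shapeless.labelled.FieldType

sealed trait Item
case class Cake(flavor: String, height: Int) extends Item
case class Hat(shape: String, material: String, color: String) extends Item

// Hypothetical type class: the constructor names in a labelled coproduct.
trait TagNames[C <: Coproduct] { def names: List[String] }

object TagNames {
  // Base case: CNil has no alternatives, so there are no names.
  implicit val cnilTagNames: TagNames[CNil] = new TagNames[CNil] {
    def names: List[String] = Nil
  }

  // Inductive step: the head alternative's label, followed by the names of the rest.
  implicit def cconsTagNames[K <: Symbol, L, R <: Coproduct](implicit
    witK: Witness.Aux[K],
    rest: TagNames[R]
  ): TagNames[FieldType[K, L] :+: R] = new TagNames[FieldType[K, L] :+: R] {
    def names: List[String] = witK.value.name :: rest.names
  }

  // Entry point: from a sealed trait to its labelled coproduct representation.
  final class NamesOf[A] {
    def apply[Repr <: Coproduct]()(implicit
      gen: LabelledGeneric.Aux[A, Repr],
      tagNames: TagNames[Repr]
    ): List[String] = tagNames.names
  }

  def namesOf[A]: NamesOf[A] = new NamesOf[A]
}
```

    Calling TagNames.namesOf[Item]() yields the names "Cake" and "Hat". Their order follows the order Shapeless uses for the coproduct's alternatives.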

    Decoding these coproducts

    The decoding implementation is even a little more involved, but it follows the same pattern:

    import io.circe.{ DecodingFailure, HCursor }
    import shapeless.labelled.field
    
    trait ReprDecoder[C <: Coproduct] extends Decoder[C]
    
    object ReprDecoder {
      def wrap[A <: Coproduct](decodeA: Decoder[A]): ReprDecoder[A] =
        new ReprDecoder[A] {
          def apply(c: HCursor): Decoder.Result[A] = decodeA(c)
        }
    
      implicit val decodeCNil: ReprDecoder[CNil] = wrap(
        Decoder.failed(DecodingFailure("CNil", Nil))
      )
    
      implicit def decodeCCons[K <: Symbol, L, R <: Coproduct](implicit
        witK: Witness.Aux[K],
        decodeL: Decoder[L],
        decodeR: ReprDecoder[R]
      ): ReprDecoder[FieldType[K, L] :+: R] = wrap(
        decodeL.prepare(_.downField("contents")).validate(
          _.downField("tag").focus
            .flatMap(_.as[String].right.toOption)
            .contains(witK.value.name),
          witK.value.name
        )
        .map(l => Inl[FieldType[K, L], R](field[K](l)))
        .or(decodeR.map[FieldType[K, L] :+: R](Inr(_)))
      )
    }
    

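    Decoder implementations generally need a little more logic than encoders, since each decoding step can fail. Here each alternative only succeeds if the "tag" field matches its label and the "contents" field decodes. Otherwise, or falls back to the rest of the coproduct, ending at the CNil decoder, which always fails.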
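    The combinators in decodeCCons may be easier to follow in a hand-written, non-generic version. The sketch below uses the same prepare, validate, map and or calls for the two-case Item. It assumes Scala 2.12+ so that Either#toOption is available. The `TaggedByHand` object and the `tagged` helper are hypothetical names:

```scala
import io.circe.{ Decoder, HCursor }
import io.circe.parser.decode

sealed trait Item
case class Cake(flavor: String, height: Int) extends Item
case class Hat(shape: String, material: String, color: String) extends Item

object TaggedByHand {
  // Hypothetical helper: decode "contents" with `contents`, but only if "tag" equals `name`.
  // validate sees the original (top-level) cursor; prepare moves the inner decoder down.
  def tagged[A](name: String)(contents: Decoder[A]): Decoder[A] =
    contents
      .prepare(_.downField("contents"))
      .validate(
        (c: HCursor) => c.downField("tag").as[String].toOption.contains(name),
        s"expected tag $name"
      )

  val decodeCake: Decoder[Item] =
    tagged("Cake")(Decoder[(String, Int)]).map[Item] { case (f, h) => Cake(f, h) }

  val decodeHat: Decoder[Item] =
    tagged("Hat")(Decoder[(String, String, String)]).map[Item] {
      case (s, m, c) => Hat(s, m, c)
    }

  // If the first alternative fails, try the next one (compare `.or(decodeR...)` above).
  implicit val decodeItem: Decoder[Item] = decodeCake.or(decodeHat)
}
```

    Adding a third case here means writing another tagged(...) decoder and another .or. decodeCCons removes that repetition.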

    Our ADT representation

    Now we can put it all together:

    import shapeless.{ LabelledGeneric, Lazy }
    
    object Derivation extends FlatCaseClassCodecs {
      implicit def encodeAdt[A, Repr <: Coproduct](implicit
        gen: LabelledGeneric.Aux[A, Repr],
        encodeRepr: Lazy[ReprEncoder[Repr]]
      ): ObjectEncoder[A] = encodeRepr.value.contramapObject(gen.to)
    
      implicit def decodeAdt[A, Repr <: Coproduct](implicit
        gen: LabelledGeneric.Aux[A, Repr],
        decodeRepr: Lazy[ReprDecoder[Repr]]
      ): Decoder[A] = decodeRepr.value.map(gen.from)
    }
    

    This looks very similar to the definitions in our FlatCaseClassCodecs above, and the idea is the same: we're defining instances for our data type (either case classes or ADTs) by building on the instances for the generic representations of those data types. Note that I'm extending FlatCaseClassCodecs, again to minimize imports for the user.

    In action

    Now we can use these instances like this:

    scala> import Derivation._
    import Derivation._
    
    scala> item1.asJson.noSpaces
    res7: String = {"tag":"Cake","contents":["cherry",100]}
    
    scala> item2.asJson.noSpaces
    res8: String = {"tag":"Hat","contents":["cowboy","felt","brown"]}
    

    …which is exactly what we wanted. And the best part is that this will work for any sealed trait hierarchy in Scala, no matter how many case classes it has or how many members those case classes have, as long as all of the member types have JSON representations. (Compile times will start to hurt once you're into the dozens of either, though.)