scala shapeless coproduct

Shapeless: Iterate over the types in a Coproduct


I want to do something really simple, but I'm struggling to craft the right search query, or to understand some of the solutions I've seen.

Given a method that takes a generic type parameter which is a Coproduct:

def apply[T <: Coproduct] = {
  ...
}

How can I iterate over the types that form the coproduct? Specifically, for each type that is a case class, I'd like to recursively examine each field and build up a map with all the information.

Currently I'm working around this using a builder pattern, which I'll post here in case it's useful to others:

class ThingMaker[Entities <: Coproduct] private {
  def doThings(item: Entities): Set[Fact] = {
    ...
  }

  def register[A <: Product with Serializable]: ThingMaker[A :+: Entities] = {
    // useful work can be done here on a per type basis
    new ThingMaker[A :+: Entities]
  }
}

object ThingMaker {
  def register[A <: Product with Serializable]: ThingMaker[A :+: CNil] = {
    // useful work can be done here on a per type basis
    new ThingMaker[A :+: CNil]
  }
}

Solution

  • If you just want to inspect values, you can simply pattern match on a coproduct like on any other value...

    def apply[T <: Coproduct](co: T): Any = co match {
      case Inl(MyCaseClass(a, b, c)) => ???
      ...
    }
    

    ...but if you want to be more precise than that, for instance to have a return type that depends on the input, or to inspect the types inside the coproduct in order to summon implicits, then you can write the exact same pattern matching expression as a type class with several implicit definitions:

    trait MyFunction[T <: Coproduct] {
      type Out
      def apply(co: T): Out
    }
    
    object MyFunction {
      // case Inl(MyCaseClass(a, b, c)) =>
      // Inl takes two type parameters: the head type and the tail coproduct
      implicit def case1[T <: Coproduct] = new MyFunction[Inl[MyCaseClass, T]] {
        type Out = Nothing
        def apply(co: Inl[MyCaseClass, T]): Out = ???
      }
    
      // ...
    }
    

    In general, when you want to iterate over all the types of a coproduct, you always follow the same tail-recursive structure. As a function:

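    def iterate[T <: Coproduct](co: T): Any = co match {
      case Inl(head: Any)       => println(head)        // Inl wraps the head value
      case Inr(tail: Coproduct) => iterate(tail)        // Inr wraps the rest of the coproduct
      case cnil: CNil           => cnil.impossible      // unreachable: CNil has no values
    }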
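
    For a concrete coproduct, the same matching can be spelled out one alternative at a time. The following is only a sketch, assuming shapeless 2.3.x on Scala 2; User, Order, Entity and describe are hypothetical names made up for the example:

    ```scala
    import shapeless.{:+:, CNil, Coproduct, Inl, Inr}

    // Hypothetical case classes, introduced only for this example.
    final case class User(name: String, age: Int)
    final case class Order(id: Long)

    object MatchExample {
      type Entity = User :+: Order :+: CNil

      // Inl holds the head alternative; each Inr steps one level into the tail.
      def describe(co: Entity): String = co match {
        case Inl(User(name, age)) => s"user $name, aged $age"
        case Inr(Inl(Order(id)))  => s"order $id"
        case Inr(Inr(cnil))       => cnil.impossible // CNil has no values
      }

      // Coproduct[Entity](value) injects a value into the matching alternative.
      val user: Entity  = Coproduct[Entity](User("Ann", 30))
      val order: Entity = Coproduct[Entity](Order(42L))
    }
    ```

    Here MatchExample.describe(MatchExample.user) should give "user Ann, aged 30", and the order value should give "order 42".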
    

    Or as a "dependently typed function":

    trait Iterate[T <: Coproduct]
    object Iterate {
      implicit def caseCNil = new Iterate[CNil] {...}
      implicit def caseCCons[H, T <: Coproduct](implicit rec: Iterate[T]) =
        new Iterate[H :+: T] {...}
    }
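
    To make the skeleton above concrete, here is one way the elided bodies could be filled in. It is a sketch rather than the only possible shape, again assuming shapeless 2.3.x; Show and ShowAny are hypothetical type classes invented for the example, with Show standing in for whatever per-type implicit you want to summon:

    ```scala
    import shapeless.{:+:, CNil, Coproduct, Inl, Inr}

    // Hypothetical per-type type class; any type class with implicit instances would do.
    trait Show[A] { def show(a: A): String }

    object Show {
      implicit val showInt: Show[Int] =
        new Show[Int] { def show(a: Int): String = s"Int($a)" }
      implicit val showString: Show[String] =
        new Show[String] { def show(a: String): String = s"String($a)" }
    }

    // Hypothetical "function on coproducts", resolved by implicit search.
    trait ShowAny[C <: Coproduct] { def apply(co: C): String }

    object ShowAny {
      // Base case: CNil has no values, so this body is never actually reached.
      implicit val caseCNil: ShowAny[CNil] = new ShowAny[CNil] {
        def apply(co: CNil): String = co.impossible
      }

      // Inductive case: summon Show for the head type H, delegate the tail to rec.
      implicit def caseCCons[H, T <: Coproduct](implicit
          showHead: Show[H],
          rec: ShowAny[T]
      ): ShowAny[H :+: T] = new ShowAny[H :+: T] {
        def apply(co: H :+: T): String = co match {
          case Inl(head) => showHead.show(head)
          case Inr(tail) => rec(tail)
        }
      }
    }

    object ShowAnyExample {
      type IntOrString = Int :+: String :+: CNil
      val shown: String =
        implicitly[ShowAny[IntOrString]].apply(Coproduct[IntOrString]("hi"))
    }
    ```

    ShowAnyExample.shown should be "String(hi)". If some alternative has no Show instance, the implicit search fails at compile time instead of at run time.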
    

    You can, for instance, obtain the name of each type in a coproduct using an additional ClassTag implicit:

    trait Iterate[T <: Coproduct] { def types: List[String] }
    
    object Iterate {
      implicit def caseCNil = new Iterate[CNil] {
        def types: List[String] = Nil
      }
    
      implicit def caseCCons[H, T <: Coproduct]
        (implicit
          rec: Iterate[T],
          ct: reflect.ClassTag[H]
        ) =
          new Iterate[H :+: T] {
            def types: List[String] = ct.runtimeClass.getName :: rec.types
          }
    }
    
    implicitly[Iterate[Int :+: String :+: CNil]].types // List(int, java.lang.String)
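
    The same structure can be pushed a step closer to what the question asks for, by also summoning the field names of each case class. This is a hedged sketch, assuming shapeless 2.3.x on Scala 2 (where LabelledGeneric keys are Symbols); FieldNames, Fields, Person and Address are hypothetical names, and it only goes one level deep rather than recursing into nested fields:

    ```scala
    import scala.reflect.ClassTag
    import shapeless.{:+:, CNil, Coproduct, HList, LabelledGeneric}
    import shapeless.ops.hlist.ToTraversable
    import shapeless.ops.record.Keys

    // Hypothetical helper type class: the field names of a case class A.
    trait FieldNames[A] { def names: List[String] }

    object FieldNames {
      implicit def caseClass[A, Repr <: HList, K <: HList](implicit
          gen: LabelledGeneric.Aux[A, Repr],         // record representation of A
          keys: Keys.Aux[Repr, K],                   // its keys, as an HList of Symbols
          toList: ToTraversable.Aux[K, List, Symbol] // converts that HList to a List
      ): FieldNames[A] = new FieldNames[A] {
        def names: List[String] = toList(keys()).map(_.name)
      }
    }

    // Hypothetical: class name -> field names, for every alternative of a coproduct.
    trait Fields[C <: Coproduct] { def fields: Map[String, List[String]] }

    object Fields {
      implicit val caseCNil: Fields[CNil] = new Fields[CNil] {
        def fields: Map[String, List[String]] = Map.empty
      }

      implicit def caseCCons[H, T <: Coproduct](implicit
          ct: ClassTag[H],
          fn: FieldNames[H],
          rec: Fields[T]
      ): Fields[H :+: T] = new Fields[H :+: T] {
        def fields: Map[String, List[String]] =
          rec.fields + (ct.runtimeClass.getName -> fn.names)
      }
    }

    // Hypothetical case classes, introduced only for this example.
    final case class Person(name: String, age: Int)
    final case class Address(street: String)

    object FieldsExample {
      val result: Map[String, List[String]] =
        implicitly[Fields[Person :+: Address :+: CNil]].fields
    }
    ```

    The map in FieldsExample.result is keyed by the runtime class name, so the exact keys depend on the package the case classes live in; the values should be List("name", "age") and List("street").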
    

    Because of the way Scala lets you influence implicit priority, it's actually possible to translate any recursive function with pattern matching into this "dependently typed function" pattern. This is unlike Haskell, where such a function can only be written if all cases of the match expression are provably non-overlapping.
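
    As an illustration of overlapping cases, the sketch below counts how many alternatives of a coproduct are String. The String case and the catch-all case both apply when the head is String; placing the catch-all in a parent trait gives it lower priority. CountStrings and LowPriorityCountStrings are hypothetical names, and shapeless 2.3.x is assumed:

    ```scala
    import shapeless.{:+:, CNil, Coproduct}

    // Hypothetical type class: number of String alternatives in a coproduct.
    trait CountStrings[C <: Coproduct] { def count: Int }

    trait LowPriorityCountStrings {
      // Catch-all case (like a trailing `case _ =>`): any head type, contributes nothing.
      implicit def otherHead[H, T <: Coproduct](implicit
          rec: CountStrings[T]
      ): CountStrings[H :+: T] =
        new CountStrings[H :+: T] { def count: Int = rec.count }
    }

    object CountStrings extends LowPriorityCountStrings {
      implicit val caseCNil: CountStrings[CNil] =
        new CountStrings[CNil] { def count: Int = 0 }

      // Overlaps with otherHead when H = String, but is preferred because it is
      // defined in the companion object, which inherits from the low-priority trait.
      implicit def stringHead[T <: Coproduct](implicit
          rec: CountStrings[T]
      ): CountStrings[String :+: T] =
        new CountStrings[String :+: T] { def count: Int = 1 + rec.count }
    }

    object CountStringsExample {
      val n: Int =
        implicitly[CountStrings[Int :+: String :+: String :+: CNil]].count
    }
    ```

    CountStringsExample.n should be 2: the Int head falls through to otherHead, and each String head is picked up by stringHead.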