Tags: scala, typeclass, implicit, shapeless, generic-derivation

Shapeless - How to derive LabelledGeneric for Coproduct


I'm trying to generate a LabelledGeneric for a Coproduct, so that it can be used instead of a typical sealed-trait hierarchy. So far I have only been able to do it by explicitly specifying the labels for DefaultSymbolicLabelling, but I feel it should be possible to derive them automatically from the coproduct's type members.

/**
 * So far I found no way to derive `L` and `l` from `C`.
 *
 * Hand-written `Generic` and `DefaultSymbolicLabelling` instances for a non-sealed `Base`:
 * `C` lists the subclasses the representation covers and `L`/`l` give their labels,
 * so that `LabelledGeneric[Base]` can be summoned without sealing the trait.
 */
object ShapelessLabelledGenericForCoproduct extends App {
  trait Base // not sealed!

  case class Case1(a: Int) extends Base

  case class Case2(a: String) extends Base

  case class Case3(b: Boolean) extends Base

  object Base {
    /** The subclasses of `Base` covered by the generic representation. */
    type C = Case1 :+: Case2 :+: Case3 :+: CNil

    /** Labels for the members of `C`, in the same order as `C`. */
    type L = (Symbol @@ "Case1") :: (Symbol @@ "Case2") :: (Symbol @@ "Case3") :: shapeless.HNil
    val l: L = tag["Case1"](Symbol("Case1")) :: tag["Case2"](Symbol("Case2")) :: tag["Case3"](Symbol("Case3")) :: HNil

    // `Base` is not sealed, so a value may belong to a subclass that is not listed in `C`;
    // report that explicitly instead of failing with an opaque `None.get`.
    implicit def myGeneric: Generic.Aux[Base, C] = Generic.instance[Base, C](
      v => Coproduct.runtimeInject[C](v).getOrElse(
        throw new NoSuchElementException(s"$v is not an instance of any subclass listed in Base.C")
      ),
      v => Coproduct.unsafeGet(v).asInstanceOf[Base]
    )

    implicit def mySymbolicLabelling: DefaultSymbolicLabelling.Aux[Base, L] = DefaultSymbolicLabelling.instance[Base, L](l)
  }

  // Round trip: encode to the labelled coproduct and decode back.
  val lgen = LabelledGeneric[Base]
  val repr = lgen.to(Case1(123))
  println(lgen.from(repr))
}

See the code below with a sealed trait; in general I'd like to achieve similar behavior, just without sealing the trait.

/** The same round trip with a sealed `Base`, where shapeless derives `LabelledGeneric` on its own. */
object ShapelessLabelledGenericForSealedTrait extends App {
  // Sealed: this is what lets shapeless's macros enumerate the subclasses below.
  sealed trait Base

  case class Case1(a: Int) extends Base

  case class Case2(a: String) extends Base

  case class Case3(b: Boolean) extends Base

  // Encode a value to its labelled coproduct representation, then decode it back.
  val lgen = LabelledGeneric[Base]
  val repr = lgen.to(Case1(123))
  println(lgen.from(repr))
}

Any hints? I looked through the shapeless macros, but so far I have found nothing useful...

m.


Solution

  • For a trait that is not sealed, the Generic/LabelledGeneric instances defined in Shapeless can't work.

    All such macros use .knownDirectSubclasses, which works only for sealed traits.

    Scala reflection: knownDirectSubclasses only works for sealed traits?

    For a trait that is not sealed, I can always add inheritors of Base in a different file (case class Case4() extends Base) or even at runtime (toolbox.define(q"case class Case4() extends Base")), so the complete set of subclasses is simply not known when the instance is derived.

    If you're interested only in inheritors defined in the sources currently being compiled, then you can avoid .knownDirectSubclasses and write a macro that traverses the ASTs of the compilation units and looks for the inheritors.


    > So far I found no way to derive L and l from C.

    It's not hard:

    import scala.language.experimental.macros
    import scala.reflect.macros.whitebox
    
    /** Type-level name of `A`: `Out` is the singleton string type of `A`'s simple (decoded) name. */
    trait ToName[A] {
      type Out <: String with Singleton
    }
    object ToName {
      type Aux[A, Out0 <: String with Singleton] = ToName[A] { type Out = Out0 }

      // Whitebox implicit macro: the type of the expanded tree determines `Out` at the call site.
      implicit def mkToName[A, Out <: String with Singleton]: Aux[A, Out] = macro mkToNameImpl[A]

      def mkToNameImpl[A: c.WeakTypeTag](c: whitebox.Context): c.Tree = {
        import c.universe._
        val A = weakTypeOf[A]
        // Use the decoded name, as shapeless's own labelling does, so that e.g. a
        // backquoted `My-Case` yields "My-Case" rather than the encoded "My$minusCase".
        val name = A.typeSymbol.name.decodedName.toString
        q"""
          new ToName[$A] {
            type Out = $name
          }
        """
      }
    }

    implicitly[ToName.Aux[Case1, "Case1"]] // compiles
    
    import shapeless.ops.coproduct.ToHList
    import shapeless.tag.@@
    import shapeless.{:+:, ::, CNil, HList, HNil, Poly0, Poly1, Witness, tag, the}
    
    /**
     * Maps a value of a `Base` subclass `A` to `A`'s name as a tagged symbol;
     * the value itself is never inspected.
     */
    object toNamePoly extends Poly1 {
      implicit def cse[A <: Base, S <: String with Singleton](implicit
        toName: ToName.Aux[A, S],
        witness: Witness.Aux[S],
        // valueOf: ValueOf[S],
      ): Case.Aux[A, Symbol @@ S] = {
        // The label depends only on the type, so every input maps to the same value.
        val label: Symbol @@ S = tag[S](Symbol(witness/*valueOf*/.value))
        at(_ => label)
      }
    }
    
    /**
     * Yields a `null` placeholder for any type `A`. It is used here to build an HList
     * of the right shape, whose elements are only looked at by type.
     */
    object nullPoly extends Poly0 {
      implicit def default[A]: Case0[A] = {
        val placeholder: A = null.asInstanceOf[A]
        at[A](placeholder)
      }
    }
    
    // Build an HList with one null placeholder per member type of C, then map each
    // element to the tagged symbol of its type's name.
    val res = HList.fillWith[the.`ToHList[C]`.Out](nullPoly).map(toNamePoly)
    
    res: L // compiles (type ascription: res has exactly the label type L)
    res == l // true
    

    So you can derive DefaultSymbolicLabelling as follows

    import shapeless.ops.coproduct.ToHList
    import shapeless.ops.hlist.{FillWith, Mapper}
    
    /** Labels for `Base` derived from `C`: one tagged symbol per coproduct member, named after its type. */
    implicit def mySymbolicLabelling[L <: HList](implicit
      toHList: ToHList.Aux[C, L],
      fillWith: FillWith[nullPoly.type, L],
      mapper: Mapper[toNamePoly.type, L],
    ): DefaultSymbolicLabelling.Aux[Base, mapper.Out] = {
      // `toHList` only serves to compute `L`, the HList of C's member types.
      val placeholders: L = fillWith()
      val labels: mapper.Out = mapper(placeholders)
      DefaultSymbolicLabelling.instance[Base, mapper.Out](labels)
    }
    

    Here is the code that does the traversal. It introduces a type class KnownSubclasses:

    import shapeless.Coproduct
    import scala.collection.mutable
    import scala.language.experimental.macros
    import scala.reflect.macros.whitebox
    
    /**
     * Type class whose `Out` is the coproduct of all subclasses of `A` (excluding `A`
     * itself) declared in the sources of the current compilation run, in traversal order.
     */
    trait KnownSubclasses[A] {
      type Out <: Coproduct
    }
    object KnownSubclasses {
      type Aux[A, Out0 <: Coproduct] = KnownSubclasses[A] { type Out = Out0 }
    
      implicit def mkKnownSubclasses[A, Out <: Coproduct]: Aux[A, Out] = macro mkKnownSubclassesImpl[A]
    
      def mkKnownSubclassesImpl[A: c.WeakTypeTag](c: whitebox.Context): c.Tree = {
        import c.universe._
        val A = weakTypeOf[A]
    
        // ListBuffer gives amortized O(1) appends while collecting the subclasses.
        val children = mutable.ListBuffer.empty[Type]
    
        // subclasses of A (direct and indirect)
        val traverser = new Traverser {
          override def traverse(tree: Tree): Unit = {
            tree match {
              // Skip ClassDefs that have no class symbol (NoSymbol), where `.asClass`
              // would throw; presumably e.g. local classes whose bodies are not yet typechecked.
              case cd: ClassDef if cd.symbol.isClass =>
                val tpe = cd.symbol.asClass.toType
                if (tpe <:< A && !(tpe =:= A)) children += tpe
              case _ =>
            }
    
            super.traverse(tree)
          }
        }
    
    //  def getType(t: Tree): Type = {
    //    val withoutArgs = t match {
    //      case q"${t1@tq"$_[..$_]"}(...$_)" => t1
    //      case _ => t
    //    }
    //    c.typecheck(tq"$withoutArgs", mode = c.TYPEmode).tpe
    //  }
    //
    //  // direct subclasses of A
    //  val traverser = new Traverser {
    //    override def traverse(tree: Tree): Unit = {
    //      tree match {
    //        case q"$_ class $_[..$_] $_(...$_) extends { ..$_ } with ..$parents { $_ => ..$_ }"
    //          if parents.exists(getType(_) =:= A) =>
    //            children :+= tree.symbol.asClass.toType
    //        case _ =>
    //      }
    //
    //      super.traverse(tree)
    //    }
    //  }
    
        // Visit every compilation unit of the current run, not just the enclosing file.
        // NOTE: `c.enclosingRun` is deprecated since Scala 2.11; it is used because the
        // macro API appears to offer no non-deprecated way to enumerate all units.
        c.enclosingRun.units.foreach(unit => traverser.traverse(unit.body))
    
        val coprod = children.foldRight[Tree](tq"_root_.shapeless.CNil")((child, copr) => tq"_root_.shapeless.:+:[$child, $copr]")
    
        q"""
          new KnownSubclasses[$A] {
            type Out = $coprod
          }
        """
      }
    }
    
    implicitly[KnownSubclasses.Aux[Base, Case1 :+: Case2 :+: Case3 :+: CNil]] // compiles
    

    (This implementation doesn't work for generic traits and classes.)

    So you can derive Generic and DefaultSymbolicLabelling (and therefore LabelledGeneric) as follows

    import shapeless.ops.coproduct.{RuntimeInject, ToHList}
    import shapeless.ops.hlist.{FillWith, Mapper}
    
    /**
     * `Generic` for the non-sealed `Base`, whose representation `C` is the coproduct
     * of subclasses found by `KnownSubclasses`.
     */
    implicit def myGeneric[C <: Coproduct](implicit
      knownSubclasses: KnownSubclasses.Aux[Base, C],
      runtimeInject: RuntimeInject[C]
    ): Generic.Aux[Base, C] = Generic.instance[Base, C](
      // `C` only contains the subclasses seen at compile time, so another subclass of
      // `Base` cannot be injected; fail with a clear message instead of `None.get`.
      v => Coproduct.runtimeInject[C](v).getOrElse(
        throw new NoSuchElementException(s"$v is not an instance of any known subclass of Base")
      ),
      v => Coproduct.unsafeGet(v).asInstanceOf[Base]
    )
    
    /**
     * Labels for `Base` derived from the subclasses found by `KnownSubclasses`:
     * one tagged symbol per coproduct member, named after its type.
     */
    implicit def mySymbolicLabelling[C <: Coproduct, L <: HList](implicit
      knownSubclasses: KnownSubclasses.Aux[Base, C],
      toHList: ToHList.Aux[C, L],
      fillWith: FillWith[nullPoly.type, L],
      mapper: Mapper[toNamePoly.type, L],
    ): DefaultSymbolicLabelling.Aux[Base, mapper.Out] = {
      // `knownSubclasses` and `toHList` only fix the types C and L; values come from fillWith/mapper.
      val placeholders: L = fillWith()
      val labels: mapper.Out = mapper(placeholders)
      DefaultSymbolicLabelling.instance[Base, mapper.Out](labels)
    }