Tags: scala, scala-macros, scala-3

Scala 3 constructor inheritance with macros


Every class implementing the trait must declare a constructor that sets the trait's fields:

/** Common shape of every payload-carrying case class; only `description` differs per case. */
sealed trait WithPayload:
    def description: String
    def payload1: Int
    def payload2: Long
end WithPayload

// Foo restates every WithPayload field as a constructor parameter.
final case class Foo(
    override val payload1: Int,
    override val payload2: Long
) extends WithPayload:
    override def description: String = "foo"
end Foo

// Bar has to restate the very same fields once more.
final case class Bar(
    override val payload1: Int,
    override val payload2: Long
) extends WithPayload:
    override def description: String = "bar"
end Bar

Is there a way to get rid of the repeated constructor declarations with a macro, something like this C-preprocessor-style definition:

// Hypothetical preprocessor macro: expands to the shared constructor parameter
// list followed by the extends clause, so each class would only name itself.
#define EXTENDS_WITH_PAYLOAD ( \
    override val payload1: Int, \
    override val payload2: Long \
) extends WithPayload

and then:

// Desired usage (pseudo-code): the macro supplies the constructor and the parent.
final case class Foo EXTENDS_WITH_PAYLOAD:
    override def description = "foo"

final case class Bar EXTENDS_WITH_PAYLOAD:
    override def description = "bar"

Solution

  • import scala.annotation.{StaticAnnotation, compileTimeOnly}
    import scala.language.experimental.macros
    import scala.reflect.macros.blackbox
    
    @compileTimeOnly("enable macro annotations")
    class extendsWithPayload extends StaticAnnotation {
      def macroTransform(annottees: Any*): Any = macro ExtendsWithPayloadMacros.macroTransformImpl
    }
    
    object ExtendsWithPayloadMacros {
      /** Expansion of `@extendsWithPayload`.
        *
        * For an annotated class it appends `override val payload1: Int, override val payload2: Long`
        * to the first parameter list (creating one if the class has none), and appends
        * `WithPayload` to the parents. Trees that follow the class among the annottees
        * (e.g. its companion object) are emitted unchanged after it.
        *
        * Compilation is aborted with a readable message if the annottee is not a class.
        */
      def macroTransformImpl(c: blackbox.Context)(annottees: c.Tree*): c.Tree = {
        import c.universe._
        // `toList` so the `::` pattern below does not depend on the runtime Seq implementation.
        annottees.toList match {
          case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: tail =>
            // Referenced unqualified: `WithPayload` must resolve at the annotation's use site.
            val parents1 = parents :+ tq"WithPayload"
            val newParams = Seq(q"override val payload1: Int", q"override val payload2: Long")
            val paramss1 = paramss match {
              case Nil => Seq(newParams)
              case params :: paramss1 => (params ++ newParams) :: paramss1
            }
            q"""
              $mods class $tpname[..$tparams] $ctorMods(...$paramss1) extends { ..$earlydefns } with ..$parents1 { $self =>
                ..$stats
              }
    
              ..$tail
            """
          case _ =>
            c.abort(c.enclosingPosition, "@extendsWithPayload can only annotate a class")
        }
      }
    }
    
    /** Interface the annotated case classes implement after expansion: the payload accessors
      * arrive as constructor vals added by the macro, `description` comes from each class body. */
    sealed trait WithPayload {
      def payload1: Int
      def payload2: Long
      def description: String
    }
    
    // Annottees: the macro adds `override val payload1: Int, override val payload2: Long` to the
    // (empty) parameter list and `WithPayload` to the parents, so after expansion `description`
    // overrides an abstract member. The expanded trees are shown in the comment below.
    @extendsWithPayload
    final case class Foo() {
      override def description = "foo"
    }
    
    @extendsWithPayload
    final case class Bar() {
      override def description = "bar"
    }
    
    //final case class Foo extends WithPayload with scala.Product with scala.Serializable {
    //    override <caseaccessor> <paramaccessor> val payload1: Int = _;
    //    override <caseaccessor> <paramaccessor> val payload2: Long = _;
    //    def <init>(payload1: Int, payload2: Long) = {
    //      super.<init>();
    //      ()
    //    };
    //    override def description = "foo"
    //  };
    //  ()
    //}
    //final case class Bar extends WithPayload with scala.Product with scala.Serializable {
    //    override <caseaccessor> <paramaccessor> val payload1: Int = _;
    //    override <caseaccessor> <paramaccessor> val payload2: Long = _;
    //    def <init>(payload1: Int, payload2: Long) = {
    //      super.<init>();
    //      ()
    //    };
    //    override def description = "bar"
    //  };
    //  ()
    //}
    

    Macro Annotations in Scala 3 (answer)

    How to generate a class in Dotty with macro? (answer)

    Scala 3 macro to create enum

    How to generate parameterless constructor at compile time using scala 3 macro?

    What can you do with MacroAnnotation that you cannot do with Macros in Scala 3?

    // Release-candidate compiler; MacroAnnotation (used below) is an @experimental API.
    scalaVersion := "3.3.0-RC4"
    
    import scala.annotation.{MacroAnnotation, experimental}
    import scala.collection.mutable
    import scala.quoted.*
    
    // `sealed` is commented out in this experiment.
    /*sealed*/ trait WithPayload:
      def payload1: Int
      def payload2: Long
      def description: String
    end WithPayload
    
    /** Experimental Scala 3 macro annotation that adds `WithPayload` as an extra parent of the
      * annotated class and prints the rewritten tree during compilation.
      *
      * NOTE: this does not reach the goal. The printed tree lists `WithPayload` as a parent,
      * yet the type checker still reports `description` as overriding nothing and rejects
      * `summon[Foo <:< WithPayload]`, as the example usage shows. The payload constructor
      * parameters are not added at all.
      */
    @experimental
    class extendsWithPayload extends MacroAnnotation:
      def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
        import quotes.reflect.*
        tree match
          case ClassDef(className, ctr, parents, self, body) =>
            val res = List(ClassDef.copy(tree)(className, ctr, parents :+ TypeTree.of[WithPayload], self, body))
            // Debug output: dumps the rewritten definition during compilation.
            println(res.map(_.show))
            res
          case _ =>
            // Non-class annottees (defs, vals, ...) get a compile error instead of a MatchError.
            report.errorAndAbort("@extendsWithPayload can only annotate a class", tree.pos)
    
    // These annottees show the limitation: although the printed tree lists WithPayload as a
    // parent, the compiler still reports the errors noted inline.
    @extendsWithPayload @experimental
    final case class Foo():
      override def description = "foo" // method description overrides nothing
    
    @extendsWithPayload @experimental
    final case class Bar():
      override def description = "bar" // method description overrides nothing
    
    // Expected to fail: both checks below are rejected (errors noted inline).
    // NOTE(review): top-level expressions are not allowed in Scala 3; presumably these lines sit
    // inside an enclosing object (the printed trees refer to App.Foo) - confirm.
    summon[Foo <:< WithPayload] // Cannot prove that Foo <:< WithPayload
    val foo = new Foo()
    foo: WithPayload // Found: (foo: Foo), Required: WithPayload
    
    //List(@scala.annotation.experimental @Macros.extendsWithPayload final case class Foo() extends Macros.WithPayload {
    //  override def hashCode(): scala.Int = scala.runtime.ScalaRunTime._hashCode(Foo.this)
    //  override def equals(x$0: scala.Any): scala.Boolean = Foo.this.eq(x$0.$asInstanceOf$[java.lang.Object]).||(x$0 match {
    //    case x$0: App.Foo @scala.unchecked =>
    //      true
    //    case _ =>
    //      false
    //  })
    //  override def toString(): java.lang.String = scala.runtime.ScalaRunTime._toString(Foo.this)
    //  override def canEqual(that: scala.Any): scala.Boolean = that.isInstanceOf[App.Foo @scala.unchecked]
    //  override def productArity: scala.Int = 0
    //  override def productPrefix: scala.Predef.String = "Foo"
    //  override def productElement(n: scala.Int): scala.Any = n match {
    //    case _ =>
    //      throw new java.lang.IndexOutOfBoundsException(n.toString())
    //  }
    //  override def description: java.lang.String = "foo"
    //})
    //List(@scala.annotation.experimental @Macros.extendsWithPayload final case class Bar() extends Macros.WithPayload {
    //  override def hashCode(): scala.Int = scala.runtime.ScalaRunTime._hashCode(Bar.this)
    //  override def equals(x$0: scala.Any): scala.Boolean = Bar.this.eq(x$0.$asInstanceOf$[java.lang.Object]).||(x$0 match {
    //    case x$0: App.Bar @scala.unchecked =>
    //      true
    //    case _ =>
    //      false
    //  })
    //  override def toString(): java.lang.String = scala.runtime.ScalaRunTime._toString(Bar.this)
    //  override def canEqual(that: scala.Any): scala.Boolean = that.isInstanceOf[App.Bar @scala.unchecked]
    //  override def productArity: scala.Int = 0
    //  override def productPrefix: scala.Predef.String = "Bar"
    //  override def productElement(n: scala.Int): scala.Any = n match {
    //    case _ =>
    //      throw new java.lang.IndexOutOfBoundsException(n.toString())
    //  }
    //  override def description: java.lang.String = "bar"
    //})
    

    Macro annotation to override toString of Scala function (answer)

    In Java or Scala, is there a way to add a callback to an exception so that when the exception is caught, the callback is invoked? (answer)

    Scala conditional compilation

    How to merge multiple imports in scala?

    Scalac 2.13 compiling large auto generated scala file: Method too large

    project/build.sbt

    // Needed by the build definition itself: project/*.scala parses and rewrites sources with scalameta.
    libraryDependencies += "org.scalameta" %% "scalameta" % "4.7.7"
    

    build.sbt

    ThisBuild / scalaVersion := "3.2.2"
    
    // Holds the @extendsWithPayload annotation class so the untransformed sources compile.
    lazy val common = project
    
    // Hand-written sources that still use @extendsWithPayload.
    lazy val before = project.dependsOn(common)
    
    // Compiles the rewritten copies that `transform` writes into this project's sourceManaged
    // directory. No dependency on `common`: the rewrite removes the annotation from the classes it expands.
    lazy val after = project
      .settings(
        Compile / unmanagedSourceDirectories += (Compile / sourceManaged).value
      )
    
    lazy val transform = taskKey[Unit]("Transform sources")
    
    transform := Generator.gen(
      inputDir = (before / Compile / scalaSource).value,
      outputDir = (after / Compile / sourceManaged).value
    )
    

    project/Generator.scala

    import sbt.*
    
    /** Copies every `*.scala` file under an input directory to the same relative path under an
      * output directory. Selected files are passed through [[AnnotationProcessor]]; the rest are
      * copied verbatim.
      */
    object Generator {
      /** Default for `filesToTransform`: the empty selection, meaning "transform every file". */
      val ALL: Seq[String] = Seq()
    
      /** True when `filesToTransform` selects every file, i.e. when it is empty. */
      def isAll(filesToTransform: Seq[String]): Boolean = filesToTransform.isEmpty
    
      /** Runs the transform over a source tree.
        *
        * @param inputDir         root directory searched recursively for `*.scala` files
        * @param outputDir        root directory the results are written to, mirroring `inputDir`'s layout
        * @param filesToTransform file names (not paths) to rewrite; other files are copied unchanged.
        *                         Empty (the default, [[ALL]]) means every file is rewritten.
        */
      def gen(
               inputDir: File,
               outputDir: File,
               filesToTransform: Seq[String] = ALL,
             ): Unit = {
        val finder: PathFinder = inputDir ** "*.scala"
        val scalametaTransformer = new AnnotationProcessor()
        val shouldTransform: File => Boolean =
          file => isAll(filesToTransform) || filesToTransform.contains(file.name)
    
        // Plain side-effecting loop: there is nothing to collect from the iterations.
        for (inputFile <- finder.get) {
          val inputStr = IO.read(inputFile)
          val outputStr = if (shouldTransform(inputFile)) scalametaTransformer(inputStr) else inputStr
          // `relativeTo` is always defined here because every file was found under `inputDir`.
          val outputFile = outputDir / inputFile.relativeTo(inputDir).get.toString
          IO.write(outputFile, outputStr)
        }
      }
    }
    

    project/AnnotationProcessor.scala

    import scala.meta.*
    
    /** Source-level stand-in for a macro annotation. Every class annotated with
      * `@extendsWithPayload` is rewritten so that it:
      *   - appends `override val payload1: Int, override val payload2: Long` to its first
      *     parameter list (creating one if there is none),
      *   - adds `WithPayload` to its parents, and
      *   - no longer carries the `@extendsWithPayload` annotation.
      * All other trees are left as they are. `super.apply` descends into child trees, so
      * annotated classes nested inside other definitions are rewritten too.
      */
    class AnnotationProcessor extends TreeTransformer {
      /** Matches the bare `@extendsWithPayload` annotation (unqualified name, no arguments). */
      val isExtendsWithPayload: Mod => Boolean = { case mod"@extendsWithPayload" => true; case _ => false }
    
      override def apply(tree: Tree): Tree = {
        val tree1 = tree match {
          case q"..$mods class $tname[..$tparams] ..$ctorMods (...$paramss) $template" if mods.exists(isExtendsWithPayload) =>
            val mods1 = mods.filterNot(isExtendsWithPayload)
            template match {
              case template"{ ..$earlyStats } with ..$inits { $self => ..$stats }" =>
                val inits1 = inits :+ init"WithPayload"
                val template1 = template"{ ..$earlyStats } with ..$inits1 { $self => ..$stats }"
                val newParams = List(param"override val payload1: Int", param"override val payload2: Long")
                val paramss1: List[Term.ParamClause] = paramss match {
                  case Nil => List(newParams)
                  case params :: paramss1 => (params ++ newParams) :: paramss1
                }
                q"..$mods1 class $tname[..$tparams] ..$ctorMods (...$paramss1) $template1"
              case unsupported =>
                // Fail the `transform` task with context instead of an anonymous MatchError.
                throw new IllegalArgumentException(
                  s"@extendsWithPayload: cannot rewrite the template of class ${tname.value}: ${unsupported.syntax}"
                )
            }
          case _ => tree
        }
    
        super.apply(tree1)
      }
    }
    

    project/StringTransformer.scala

    /** Rewrites one string into another (in this build: the text of a whole Scala source file). */
    trait StringTransformer { def apply(str: String): String }
    

    project/TreeTransformer.scala

    import scala.meta.*
    
    /** Bridges scalameta's tree-level `Transformer` to the text-level `StringTransformer`:
      * parse the text as a Scala 3 source file, transform the tree, then print it back.
      * Printing re-renders rewritten trees, so the original layout (e.g. indentation-based
      * syntax) is not necessarily preserved.
      */
    trait TreeTransformer extends Transformer with StringTransformer {
      override def apply(str: String): String = {
        // `.get` fails with an exception when the text is not valid Scala 3.
        val parsed: Source = dialects.Scala3(str).parse[Source].get
        apply(parsed).toString
      }
    }
    

    common/src/main/scala/extendsWithPayload.scala

    import scala.annotation.StaticAnnotation
    
    /** Marker annotation with no members; the source generator recognizes it by name. */
    class extendsWithPayload() extends StaticAnnotation
    

    before/src/main/scala/App.scala

    // Input to the source transform; `before` depends on `common` for the annotation class.
    sealed trait WithPayload:
      def description: String
      def payload1: Int
      def payload2: Long
    
    // Annotated: gets the payload constructor parameters and the WithPayload parent.
    @extendsWithPayload
    final case class Foo():
      override def description = "foo"
    
      // Nested annotated classes are rewritten as well.
      @extendsWithPayload
      class Nested():
        override def description = "nested"
    
    @extendsWithPayload
    final case class Bar():
      override def description = "bar"
    
    // Not annotated: copied through unchanged.
    final case class Baz()
    

    Execute `sbt after/clean transform`, which produces:

    after/target/scala-3.2.2/src_managed/main/App.scala

    // Generated by the scalameta transform. The printer re-renders the rewritten classes with braces,
    // and the @extendsWithPayload annotations are gone.
    sealed trait WithPayload {
      def description: String
      def payload1: Int
      def payload2: Long
    }
    final case class Foo(override val payload1: Int, override val payload2: Long) extends WithPayload {
      override def description = "foo"
      class Nested(override val payload1: Int, override val payload2: Long) extends WithPayload { override def description = "nested" }
    }
    final case class Bar(override val payload1: Int, override val payload2: Long) extends WithPayload { override def description = "bar" }
    final case class Baz()
    

    #define in Java

    # Preprocess only (-E), treating the Scala file as C (-x c), without line markers (-P).
    gcc -x c -E -P App.scala -o App1.scala
    

    App.scala

    // Not valid Scala on its own: run it through the C preprocessor first to expand the macro.
    // Comments like this one are stripped by the preprocessor.
    #define EXTENDS_WITH_PAYLOAD ( \
        override val payload1: Int, \
        override val payload2: Long \
    ) extends WithPayload
    
    sealed trait WithPayload:
        def description: String
        def payload1: Int
        def payload2: Long
    
    final case class Foo EXTENDS_WITH_PAYLOAD:
        override def description = "foo"
    
    final case class Bar EXTENDS_WITH_PAYLOAD:
        override def description = "bar"
    

    App1.scala

    // Preprocessor output: the macro is expanded in place and the blank lines are dropped.
    sealed trait WithPayload:
        def description: String
        def payload1: Int
        def payload2: Long
    final case class Foo ( override val payload1: Int, override val payload2: Long ) extends WithPayload:
        override def description = "foo"
    final case class Bar ( override val payload1: Int, override val payload2: Long ) extends WithPayload:
        override def description = "bar"
    

    Is it possible to use a macro to modify the generated code of a structural-typing instance invocation?

    scala.meta parent of parent of Defn.Object (Scalameta + Semanticdb)

    Quasiquotes in Scalafix