
Scala whitebox macro: how to check if class fields are of a case class type


I am trying to generate a case class from a given case class that strips Option from the fields. It needs to do this recursively, so if a field is itself a case class, then Option must be removed from its fields as well.

So far I have managed to do it for the case where none of the fields is a case class. But for the recursion I need to get the ClassTag of a field if it's a case class, and I have no idea how to do this. It seems all I can access is the syntax tree before typechecking (which I guess makes sense, since the final source code isn't formed yet). But I am wondering if it's possible to achieve this in some way.

Here is my code, with the missing part marked as a comment.

import scala.annotation.StaticAnnotation
import scala.collection.mutable
import scala.reflect.macros.blackbox.Context

import scala.language.experimental.macros
import scala.annotation.compileTimeOnly

class RemoveOptionFromFields extends StaticAnnotation {
  def macroTransform(annottees: Any*): Any = macro RemoveOptionFromFields.impl
}

object RemoveOptionFromFields {
  def impl(c: Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
    import c.universe._
 
    def modifiedClass(classDecl: ClassDef, compDeclOpt: Option[ModuleDef]) = {
      val result = classDecl match {
        case q"case class $className(..$fields) extends ..$parents { ..$body }" =>
          val fieldsWithoutOption = fields.map {
            case ValDef(mods, name, tpt, rhs) =>
              tpt.children match {
                case List(first, second) if first.toString() == "Option" =>
                  // Check if `second` is a case class?
                  // Get its fields if so
                  ValDef(mods, name, second, rhs)
                case _ => 
                  ValDef(mods, name, tpt, rhs)
              }
          }

          val withOptionRemovedFromFieldsClassDecl = q"case class WithOptionRemovedFromFields(..$fieldsWithoutOption)"

          val newCompanionDecl = compDeclOpt.fold(
            q"""
            object ${className.toTermName} {
              $withOptionRemovedFromFieldsClassDecl
            }
            """
          ) {
            compDecl =>
              val q"object $obj extends ..$bases { ..$body }" = compDecl
              q"""
              object $obj extends ..$bases {
                ..$body

                $withOptionRemovedFromFieldsClassDecl
              }
              """
          }

          q"""
            $classDecl
            $newCompanionDecl
          """
        case _ =>
          c.abort(c.enclosingPosition, "This annotation only supports case classes")
      }
      c.Expr[Any](result)
    }
 
    annottees.map(_.tree) match {
      case (classDecl: ClassDef) :: Nil => modifiedClass(classDecl, None)
      case (classDecl: ClassDef) :: (compDecl: ModuleDef) :: Nil => modifiedClass(classDecl, Some(compDecl))
      case _ => c.abort(c.enclosingPosition, "This annotation only supports classes")
    }
  }
}

Solution

  • I'm not sure I understand what kind of recursion you need. Suppose we have two case classes: the first is annotated (and refers to the second), and the second is not annotated:

    @RemoveOptionFromFields  
    case class MyClass1(mc: Option[MyClass2])  
    
    case class MyClass2(i: Option[Int])
    

    What should be the result?

    Currently the annotation transforms this into

    case class MyClass1(mc: Option[MyClass2]) 
    object MyClass1 {
      case class WithOptionRemovedFromFields(mc: MyClass2)
    } 
    
    case class MyClass2(i: Option[Int])
    

    Regarding "if the field itself is a case class then it must remove Option from its fields as well":

    A macro annotation can rewrite only the annotated class and its companion. It can't rewrite other classes. In my example with two classes, the annotation can modify MyClass1 and its companion but can't rewrite MyClass2 or its companion. For that, MyClass2 would have to be annotated itself.

    Within a scope, macro annotations are expanded before that scope is typechecked, so the trees being rewritten are untyped. If you need some trees to be typed (so that you can find their symbols), you can call c.typecheck on them.

    See: Scala macros: What is the difference between typed (aka typechecked) and untyped Trees
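    You can observe the same effect outside a macro with runtime reflection. The sketch below is a runtime analogue of c.typecheck(..., mode = c.TYPEmode), not macro code. It assumes scala-compiler is on the classpath (ToolBox lives there), and the object name TypecheckSketch is made up for illustration. It shows that a freshly built type tree has no type until it is typechecked.

```scala
import scala.reflect.runtime.currentMirror
import scala.reflect.runtime.universe._
import scala.tools.reflect.ToolBox

// Runtime analogue of c.typecheck(tree, mode = c.TYPEmode) in a macro.
// Assumes scala-compiler is on the classpath (ToolBox lives there).
object TypecheckSketch {
  private val tb = currentMirror.mkToolBox()

  // An untyped type tree, like the `tpt` of a field seen by a macro annotation
  val untyped: Tree = tq"Option[Int]"

  // Read before typechecking: an untyped tree carries no type (null)
  val tpeBefore: Type = untyped.tpe

  // Typecheck it as a type; the result has `tpe` (and `symbol`) populated
  val typed: Tree = tb.typecheck(untyped, mode = tb.TYPEmode)

  def main(args: Array[String]): Unit = {
    println(s"before: $tpeBefore")
    println(s"after: ${typed.tpe}, symbol: ${typed.symbol}")
  }
}
```

    Inside the macro, c.typecheck plays the role of tb.typecheck, and the trees come from the annottees instead of being built by hand.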

    To check that some class is a case class, you can use symbol.isClass && symbol.asClass.isCaseClass.

    See: How to check if some T is a case class at compile time in Scala?
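    The same Symbol API exists in the runtime universe, so you can try the check in isolation. This is a minimal sketch, assuming scala-reflect is on the classpath. CaseClassCheck, Plain and check are illustrative names, not part of the question's code.

```scala
import scala.reflect.runtime.universe._

// The case-class test from the answer, run against the runtime universe.
// In a macro the same calls work on symbols from c.universe.
object CaseClassCheck {
  case class MyClass2(i: Option[Int])
  class Plain(val i: Int)

  // A class symbol that carries the CASE flag
  def isCaseClass(sym: Symbol): Boolean =
    sym.isClass && sym.asClass.isCaseClass

  // Convenience wrapper: look up the class symbol of T via its TypeTag
  def check[T: TypeTag]: Boolean = isCaseClass(typeOf[T].typeSymbol)

  def main(args: Array[String]): Unit = {
    println(check[MyClass2]) // a case class
    println(check[Plain])    // an ordinary class
    println(check[Int])      // a primitive value class
  }
}
```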

    You hardly need ClassTags for this.

    One more complication arises when MyClass1 and MyClass2 are in the same scope:

    @RemoveOptionFromFields 
    case class MyClass1(mc: Option[MyClass2])  
    
    case class MyClass2(i: Option[Int])
    

    Then, when the macro annotation on MyClass1 is expanded, the scope hasn't been typechecked yet, so it's impossible to typecheck the tree of the field definition mc: Option[MyClass2] (class MyClass2 is not known yet). If the classes are in different scopes, it's OK:

    {
      @RemoveOptionFromFields 
      case class MyClass1(mc: Option[MyClass2])
    }  
    
    case class MyClass2(i: Option[Int])
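    In the same-scope case, the failure shows up as the EmptyTree branch of c.typecheck(..., silent = true), which the modified code below relies on. The following is a runtime sketch of that branch, again using ToolBox, so it assumes scala-compiler is on the classpath. NotYetDefined is a hypothetical name that stands in for a class the typechecker can't see yet.

```scala
import scala.reflect.runtime.currentMirror
import scala.reflect.runtime.universe._
import scala.tools.reflect.ToolBox

// Runtime sketch of a silent typecheck (assumes scala-compiler on the classpath).
object SilentTypecheckSketch {
  private val tb = currentMirror.mkToolBox()

  // Symbol of a type tree, or NoSymbol if it doesn't typecheck.
  // With silent = true a failure yields EmptyTree instead of throwing.
  def symbolOfTypeTree(tpt: Tree): Symbol =
    tb.typecheck(tpt, mode = tb.TYPEmode, silent = true) match {
      case EmptyTree => NoSymbol
      case t         => t.symbol
    }

  val known: Symbol = symbolOfTypeTree(tq"Option[Int]")
  // NotYetDefined is a hypothetical, undefined type name
  val unknown: Symbol = symbolOfTypeTree(tq"NotYetDefined")

  def main(args: Array[String]): Unit = {
    println(s"Option[Int] -> $known")
    println(s"NotYetDefined -> $unknown")
  }
}
```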
    

    This is a modified version of your code (for now it just prints the fields of the second class):

    import scala.annotation.StaticAnnotation
    import scala.reflect.macros.blackbox
    import scala.language.experimental.macros
    import scala.annotation.compileTimeOnly
    
    @compileTimeOnly("enable macro annotations")
    class RemoveOptionFromFields extends StaticAnnotation {
      def macroTransform(annottees: Any*): Any = macro RemoveOptionFromFields.impl
    }
    
    object RemoveOptionFromFields {
      def impl(c: blackbox.Context)(annottees: c.Tree*): c.Tree = {
        import c.universe._
    
        def modifiedClass(classDecl: ClassDef, compDeclOpt: Option[ModuleDef]) = {
          classDecl match {
            case q"$mods class $className[..$tparams] $ctorMods(..$fields) extends { ..$earlydefns } with ..$parents { $self => ..$body }"
              if mods.hasFlag(Flag.CASE) =>
              val fieldsWithoutOption = fields.map {
                case field@q"$mods val $name: $tpt = $rhs" =>
                  tpt match {
                    case tq"$first[..${List(second)}]" =>
                      val firstType = c.typecheck(tq"$first", mode = c.TYPEmode, silent = true) match {
                        case EmptyTree => println(s"can't typecheck $first while expanding @RemoveOptionFromFields for $className"); NoType
                        case t => t.tpe
                      }
                      if (firstType <:< typeOf[Option[_]].typeConstructor) {
                        val secondSymbol = c.typecheck(tq"$second", mode = c.TYPEmode, silent = true) match {
                          case EmptyTree => println(s"can't typecheck $second while expanding @RemoveOptionFromFields for $className"); NoSymbol
                          case t => t.symbol
                        }
                        if (secondSymbol.isClass && secondSymbol.asClass.isCaseClass) {
                          val secondClassFields = secondSymbol.typeSignature.decls.toList.filter(s => s.isMethod && s.asMethod.isCaseAccessor)
                          secondClassFields.foreach(s =>
                            c.typecheck(q"$s", silent = true) match {
                              case EmptyTree => println(s"can't typecheck $s while expanding @RemoveOptionFromFields for $className")
                              case t => println(s"field ${t.symbol} of type ${t.tpe}, subtype of Option: ${t.tpe <:< typeOf[Option[_]]}")
                            }
                          )
                        }
                        q"$mods val $name: $second = $rhs"
                      } else field
                    case _ =>
                      field
                  }
              }
    
              val withOptionRemovedFromFieldsClassDecl = q"case class WithOptionRemovedFromFields(..$fieldsWithoutOption)"
    
              val newCompanionDecl = compDeclOpt.fold(
                q"""
                  object ${className.toTermName} {
                    $withOptionRemovedFromFieldsClassDecl
                  }
                """
              ) {
                compDecl =>
                  val q"$mods object $obj extends { ..$earlydefns } with ..$bases { $self => ..$body }" = compDecl
                  q"""
                    $mods object $obj extends { ..$earlydefns } with ..$bases { $self =>
                      ..$body
    
                      $withOptionRemovedFromFieldsClassDecl
                    }
                  """
              }
    
              q"""
                $classDecl
                $newCompanionDecl
              """
          }
        }
    
        annottees match {
          case (classDecl: ClassDef) :: Nil => modifiedClass(classDecl, None)
          case (classDecl: ClassDef) :: (compDecl: ModuleDef) :: Nil => modifiedClass(classDecl, Some(compDecl))
          case _ => c.abort(c.enclosingPosition, "This annotation only supports classes")
        }
      }
    }
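    To go from printing the second class's fields to the recursion you describe, you can apply the same calls (typeSymbol, isCaseClass, decls, isCaseAccessor) recursively. This is a rough sketch under several assumptions:

    - It uses runtime reflection rather than c.universe, so it runs on its own.
    - It only renders the Option-free shape of a type as a string and does not generate trees.
    - It strips one level of Option per field and ignores type parameters.
    - StripOptionSketch, stripOption and describe are illustrative names.

    Inside the macro, each step would build trees instead. The nested result would have to reference some WithOptionRemovedFromFields-style class for the inner case class, which brings back the question of what the result should be.

```scala
import scala.reflect.runtime.universe._

// Runtime-reflection sketch of "strip Option, recurse into case classes".
// It only renders a description string; a macro would build trees instead.
object StripOptionSketch {
  case class MyClass2(i: Option[Int], b: Boolean)
  case class MyClass1(mc: Option[MyClass2], n: Option[Long])

  private val OptionSym: Symbol = typeOf[Option[Any]].typeSymbol

  // Option[X] => X; any other type is returned unchanged
  def stripOption(tpe: Type): Type =
    if (tpe.typeSymbol == OptionSym) tpe.typeArgs.head else tpe

  // Option-free shape of a type, recursing into case-class fields
  // (type parameters of generic case classes are not handled)
  def describe(tpe: Type): String = {
    val inner = stripOption(tpe)
    val sym = inner.typeSymbol
    if (sym.isClass && sym.asClass.isCaseClass) {
      val fields = inner.decls.sorted
        .filter(s => s.isMethod && s.asMethod.isCaseAccessor)
        .map(s => s"${s.name}: ${describe(s.asMethod.returnType)}")
      s"${sym.name}(${fields.mkString(", ")})"
    } else inner.toString
  }

  val rendered: String = describe(typeOf[MyClass1])

  def main(args: Array[String]): Unit = println(rendered)
}
```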