Tags: scala, metaprogramming, scala-macros, scala-3

Scala 3: Finding functions with a given annotation


For Scala 3 macros, does anyone know of a way to find all functions with a given annotation?

For instance:

@fruit
def apple(): Int = ???

@fruit
def banana(): Int = ???

@fruit
def coconut(): Int = ???

@fruit
def durian(): Int = ???

def elephant(): Int = ???

@fruit
def fig(): Int = ???

I would like to get a list containing apple, banana, coconut, durian and fig. They could be defined anywhere, but in my case they will all be in a single package.


Solution

  • This solution extracts all the definitions that carry a given annotation from a given package, using compile-time reflection. To solve your problem, we divide it into three steps:

      1. collect all the methods defined in the package;
      2. keep only the methods that carry the annotation;
      3. turn each remaining method into a function that calls it.

    The macro entry point takes the package type `P`, the annotation type `A` and the result type `R`:

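    import scala.annotation.ConstantAnnotation

    // P: package to scan, A: annotation to look for, R: return type of the methods
    inline def findAllFunction[P, A <: ConstantAnnotation, R]: List[() => R] =
      ${ Implementation.myMacroImpl[P, A, R]() }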
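
    The answer never shows how the `fruit` annotation itself is declared. A minimal sketch, assuming it is a plain marker class nested in `Macros` (as the usage example's `import org.tests.Macros.fruit` suggests) and that it extends `ConstantAnnotation` to satisfy the bound on `A`; the `Orchard` object is a hypothetical caller:

```scala
import scala.annotation.ConstantAnnotation

// Assumption: the answer never shows how `fruit` is declared. Because
// findAllFunction bounds its annotation type with `A <: ConstantAnnotation`,
// a minimal marker annotation has to extend ConstantAnnotation.
object Macros:
  final class fruit extends ConstantAnnotation

  // The inline entry point `findAllFunction` would also live in this object,
  // which is why the usage example imports `org.tests.Macros.fruit`.

// Hypothetical caller: annotating a method from the question with the marker.
object Orchard:
  @Macros.fruit def apple(): Int = 1
```

    In the answer's layout, this object would sit in `package org.tests` alongside `findAllFunction`.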
    

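    The first point is straightforward. In Scala 3, top-level definitions in a package are stored in synthetic objects (one per source file, e.g. `Main$package`), so we can collect the methods declared by every class or object in the package: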

    def methodsFromPackage(packageSymbol: Symbol): List[Symbol] =
      packageSymbol.declaredTypes
        .filter(_.isClassDef)
        .flatMap(_.declaredMethods)
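
    To make this first step concrete without a compiler in the loop, here is a hypothetical runtime model: `DeclaredType` and `PackageModel` are invented names standing in for `quotes.reflect.Symbol`, and the package contents (a file named `Main.scala`, a class `Helper`, a non-class member) are assumptions. It shows how filtering on `isClassDef` and flattening `declaredMethods` reaches the top-level `def`s, but also methods of ordinary classes, which is why the second step is needed:

```scala
// Hypothetical runtime model of a type member of a package. The real macro
// reads the same information from quotes.reflect.Symbol.
final case class DeclaredType(name: String, isClassDef: Boolean, declaredMethods: List[String])

object PackageModel:
  // Mirrors `declaredTypes.filter(_.isClassDef).flatMap(_.declaredMethods)`.
  def methodsFromPackage(declaredTypes: List[DeclaredType]): List[String] =
    declaredTypes.filter(_.isClassDef).flatMap(_.declaredMethods)

  // Assumed contents of `package foo`: the top-level defs of Main.scala live
  // in the synthetic object `Main$package`; `Helper` is an ordinary class;
  // the last entry is a hypothetical non-class member, skipped by the filter.
  val fooPackage: List[DeclaredType] = List(
    DeclaredType(name = "Main$package", isClassDef = true, declaredMethods = List("check", "other")),
    DeclaredType(name = "Helper", isClassDef = true, declaredMethods = List("helper")),
    DeclaredType(name = "NotAClass", isClassDef = false, declaredMethods = List("ignored"))
  )
```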
    

    The second point is also easy: `Symbol` has a `hasAnnotation` method that we can use directly:

    // Keep only the methods annotated with `annotation`
    def methodsAnnotatedWith(
        methods: List[Symbol],
        annotation: Symbol
    ): List[Symbol] =
      methods.filter(_.hasAnnotation(annotation))
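
    The same kind of hypothetical model for the second step, applied to the methods from the question (`MethodInfo` and `FilterModel` are illustrative names, not part of the reflection API, and annotations are modeled as plain strings). Only `elephant` lacks the annotation, so it is the only method dropped:

```scala
// Hypothetical runtime model of a method symbol: a name plus the names of
// its annotations. The real macro works on quotes.reflect.Symbol instead.
final case class MethodInfo(name: String, annotations: Set[String]):
  def hasAnnotation(annotation: String): Boolean = annotations.contains(annotation)

object FilterModel:
  // Same predicate shape as `methods.filter(_.hasAnnotation(annotation))`.
  def methodsAnnotatedWith(methods: List[MethodInfo], annotation: String): List[MethodInfo] =
    methods.filter(_.hasAnnotation(annotation))

  // The methods from the question: everything except `elephant` is a @fruit.
  val questionMethods: List[MethodInfo] = List(
    MethodInfo("apple", Set("fruit")),
    MethodInfo("banana", Set("fruit")),
    MethodInfo("coconut", Set("fruit")),
    MethodInfo("durian", Set("fruit")),
    MethodInfo("elephant", Set.empty),
    MethodInfo("fig", Set("fruit"))
  )
```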
    

    The last point is a bit more challenging: we have to construct the method call ourselves, i.e. build the AST that corresponds to the call. Inspired by this example, we can invoke each definition with `Apply` (here with an empty argument list), while `Select` and `This` pick out the method to call on its owner:

    def transformToFunctionApplication(methods: List[Symbol]): Expr[List[() => R]] =
      val appliedDef = methods
        .map(definition => Select(This(definition.owner), definition))
        .map(select => Apply(select, List.empty))
        .map(apply => '{ () => ${ apply.asExprOf[R] } })
      Expr.ofList(appliedDef)
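
    For intuition, the list built for the usage example at the end of this answer is roughly equivalent to the hand-written code below. This is a sketch, not actual macro output: the real trees select the methods through `This(owner)` on the synthetic top-level object, which the hypothetical `FooStandIn` object replaces here:

```scala
// Hypothetical stand-in for `package foo` from the usage example; in the real
// setup these are top-level defs living in a synthetic `<File>$package` object.
object FooStandIn:
  def check(): Int = 10
  def other(): Int = 11

object ExpandedThunks:
  // Roughly what Select + Apply + '{ () => ... } produce: one
  // zero-argument function per annotated method, none called yet.
  val thunks: List[() => Int] =
    List(() => FooStandIn.check(), () => FooStandIn.other())

  // Applying every function, as the usage example does with `.map(_.apply())`.
  def results: List[Int] = thunks.map(_.apply())
```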
    

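    Here each call is wrapped in a lambda. If you want to return the values directly instead, change the return type and the last `map` (the return types of `findAllFunction` and `myMacroImpl` then become `List[R]` and `Expr[List[R]]` as well):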

    def transformToFunctionApplication(methods: List[Symbol]): Expr[List[R]] =
      val appliedDef = methods
        .map(definition => Select(This(definition.owner), definition))
        .map(select => Apply(select, List.empty))
        .map(apply => apply.asExprOf[R])
    
      Expr.ofList(appliedDef)
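
    The practical difference between the two variants is when the methods run. The sketch below (with hypothetical `Counter` and `EagerVsLazy` objects) hand-writes both shapes: the `List[R]` form calls every method as soon as the list is built, while the `List[() => R]` form defers each call until its function is applied:

```scala
// Hypothetical stand-ins for two annotated methods; `calls` records how
// many times any of them has actually run.
object Counter:
  var calls: Int = 0
  def check(): Int = { calls += 1; 10 }
  def other(): Int = { calls += 1; 11 }

object EagerVsLazy:
  // Shape of the `Expr[List[R]]` variant: both methods run immediately.
  def eager(): List[Int] = List(Counter.check(), Counter.other())

  // Shape of the `Expr[List[() => R]]` variant: nothing runs until applied.
  def deferred(): List[() => Int] =
    List(() => Counter.check(), () => Counter.other())
```

    Returning functions fits better when the annotated methods have side effects or are expensive; returning values is simpler when you always need every result.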
    

    To sum up, the complete macro implementation can be written as:

    import scala.quoted.*

    object Implementation {
      def myMacroImpl[P: Type, A: Type, R: Type]()(using
          Quotes
      ): Expr[List[() => R]] = {
        import quotes.reflect.*
        // Symbols of the annotation class and of the package to scan
        val annotation = TypeRepr.of[A].typeSymbol
        val moduleTarget = TypeRepr.of[P].typeSymbol

        // Step 1: all methods declared by the classes/objects of the package
        def methodsFromPackage(packageSymbol: Symbol): List[Symbol] =
          packageSymbol.declaredTypes
            .filter(_.isClassDef)
            .flatMap(_.declaredMethods)

        // Step 2: keep only the methods carrying the annotation
        def methodsAnnotatedWith(
            methods: List[Symbol],
            annotation: Symbol
        ): List[Symbol] =
          methods.filter(_.hasAnnotation(annotation))

        // Step 3: build `() => owner.this.method()` for each method
        def transformToFunctionApplication(
            methods: List[Symbol]
        ): Expr[List[() => R]] =
          val appliedDef = methods
            .map(definition => Select(This(definition.owner), definition))
            .map(select => Apply(select, List.empty))
            .map(apply => '{ () => ${ apply.asExprOf[R] } })
          Expr.ofList(appliedDef)

        val methods = methodsFromPackage(moduleTarget)
        val annotatedMethod = methodsAnnotatedWith(methods, annotation)
        transformToFunctionApplication(annotatedMethod)
      }
    }
    
    

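    Finally, you can use the macro as shown below. Scala 3 does not allow calling a macro from the same source file in which it is defined, so `Macros`/`Implementation` and the call site should live in different files: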

    package org.tests
    import org.tests.Macros.fruit
    
    package foo {
      @fruit
      def check(): Int = 10
      @fruit
      def other(): Int = 11
    }
    
    
    @main def hello: Unit = 
      println("Hello world!")
      println(Macros.findAllFunction[org.tests.foo, fruit, Int].map(_.apply())) // List(10, 11)
    
