scala, scala-macros, scala-3

How to reduce a `Term` that is reducible to a constant in a Scala 3 macro?


I want to reduce a Term instance that I know is reducible to a constant into an equivalent Term that can be used as the pattern of a CaseDef. Specifically, I need to implement the reduceTerm method of the following snippet:

import scala.deriving.Mirror
import scala.quoted.*

object LibraryCode {
    trait DiscriminationCriteria[-S] {
        transparent inline def discriminator[P <: S]: Int
    }

    transparent inline def typeNameCorrespondingTo[S](discriminator: Int)(using mirror: Mirror.SumOf[S]): String =
        ${ typeNameCorrespondingToImpl[S, mirror.MirroredElemTypes]('discriminator) }

    private def typeNameCorrespondingToImpl[S: Type, Variants: Type](discriminatorExpr: Expr[Int])(using quotes: Quotes): Expr[String] = {
        import quotes.reflect.*

        Implicits.search(TypeRepr.of[DiscriminationCriteria[S]]) match {
            case isf: ImplicitSearchFailure =>
                report.errorAndAbort(isf.explanation)

            case iss: ImplicitSearchSuccess =>
                val discriminatorByVariant = Select.unique(iss.tree, "discriminator")

                def loop[RemainingVariants: Type]: List[CaseDef] = {
                    Type.of[RemainingVariants] match {
                        case '[headVariant *: tailVariants] =>
                            val headDiscriminator: Term = discriminatorByVariant.appliedToType(TypeRepr.of[headVariant])
                            val pattern: Term = reduceTerm(headDiscriminator)
                            val rhs: Expr[String] = '{ ${headDiscriminator.asExprOf[Int]}.toString }
                            CaseDef(pattern, None, rhs.asTerm) :: loop[tailVariants]
                            
                        case '[EmptyTuple] => Nil  
                    }
                }

                val cases = loop[Variants]
                Match(discriminatorExpr.asTerm, cases).asExprOf[String]
        }
    }
    
    private def reduceTerm(using quotes: Quotes)(term: quotes.reflect.Term): quotes.reflect.Term = {
        import quotes.reflect.*
        Literal(IntConstant(term.hashCode())) // dummy implementation
    }
}

With the current implementation of reduceTerm, a call to LibraryCode.typeNameCorrespondingTo in the following context...

import scala.compiletime.erasedValue

object UserCode {
    sealed trait Animal

    case class Dog(dogField: Int) extends Animal

    case class Cat(catField: String) extends Animal
    
    object animalDc extends LibraryCode.DiscriminationCriteria[Animal] {
        override transparent inline def discriminator[P <: Animal]: Int =
            inline erasedValue[P] match {
                case _: Dog => 1
                case _: Cat => 2
            }
    }
    given animalDc.type = animalDc
    
    @main def runUserCode(): Unit = {
        val scrutinee = 1
        val typeName = LibraryCode.typeNameCorrespondingTo[Animal](scrutinee)
        println(typeName)
    }
}

... generates the following code:

scrutinee match {
  case 758192047 =>
    1.toString()
  case 432359492 =>
    2.toString()
}

As you can see, the headDiscriminator.asExprOf[Int] expression within the rhs block was expanded into a literal constant (the 1 and 2 in the generated code). That is exactly what I need for the case patterns too: they should contain the same constants.
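The reduction visible in the rhs can be reproduced without any macro: when a transparent inline method whose body is an inline match is called with a concrete type argument, the match is reduced at the call site and the result keeps its literal singleton type. A minimal sketch, assuming the question's definitions are copied into a hypothetical object InlineReductionDemo so that it compiles on its own:

```scala
import scala.compiletime.erasedValue

// Hypothetical wrapper object; the definitions inside are copied from the question.
object InlineReductionDemo:
  trait DiscriminationCriteria[-S]:
    transparent inline def discriminator[P <: S]: Int

  sealed trait Animal
  case class Dog(dogField: Int) extends Animal
  case class Cat(catField: String) extends Animal

  object animalDc extends DiscriminationCriteria[Animal]:
    override transparent inline def discriminator[P <: Animal]: Int =
      inline erasedValue[P] match
        case _: Dog => 1
        case _: Cat => 2

  // Each call is inlined and its inline match reduced at the call site, so the
  // result has a literal singleton type: these ascriptions only typecheck
  // because the reduction has already happened.
  val dogDiscriminator: 1 = animalDc.discriminator[Dog]
  val catDiscriminator: 2 = animalDc.discriminator[Cat]
```

The ascriptions `1` and `2` are accepted only because the inliner has already turned each call into a literal, which is the same reduction that produces the constants in the rhs of the generated match.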

In short, I need to replace the current implementation of reduceTerm so that the generated code is:

scrutinee match {
  case 1 =>
    1.toString()
  case 2 =>
    2.toString()
}

Evidently, the headDiscriminator.asExprOf[Int] in the rhs block is eventually reduced to a literal constant at some phase of the macro expansion. I need to either:

Is any of these possible?


Edit in response to @DmytroMitin's answer:

It's not clear how you'd like to implement reduceTerm transforming the tree into the tree Literal(IntConstant(1)).

I don't expect an implementation that analyzes the syntax tree. That would be extremely complex, if possible at all, since the expression depends on user code.

What I'm looking for is either:

As you can see in the code generated by the macro, the compiler is able to reduce the tree to a constant. But it does so later, during macro expansion, and I need the constant earlier, during macro execution.

Note that if the scrutinee is compared in a guard instead of in the pattern (as shown below), the macro does the job.

private def typeNameCorrespondingToImpl[S: Type, Variants: Type](scrutineeExpr: Expr[Int])(using quotes: Quotes): Expr[String] = {
    import quotes.reflect.*

    Implicits.search(TypeRepr.of[DiscriminationCriteria[S]]) match {
        case isf: ImplicitSearchFailure =>
            report.errorAndAbort(isf.explanation)

        case iss: ImplicitSearchSuccess =>
            val discriminatorByVariant = Select.unique(iss.tree, "discriminator")

            def loop[RemainingVariants: Type]: List[CaseDef] = {
                Type.of[RemainingVariants] match {
                    case '[headVariant *: tailVariants] =>
                        val headDiscriminator: Term = discriminatorByVariant.appliedToType(TypeRepr.of[headVariant])
                        val rhs: Expr[String] = Expr(Type.show[headVariant])
                        val caseDef = buildCaseDef(headDiscriminator, rhs)
                        caseDef :: loop[tailVariants]

                    case '[EmptyTuple] =>
                        Nil
                }
            }

            val cases = loop[Variants]
            Match(scrutineeExpr.asTerm, cases).asExprOf[String]
    }
}

private def buildCaseDef(using quotes: Quotes)(discriminator: quotes.reflect.Term, rhs: Expr[String]): quotes.reflect.CaseDef = {
    import quotes.reflect.*
    val bindSymbol = Symbol.newBind(Symbol.spliceOwner, "d", Flags.EmptyFlags, TypeRepr.of[Int])
    val pattern = Bind(bindSymbol, Wildcard())
    val guard = Select.overloaded(
        Ref(bindSymbol),
        "==",
        Nil,
        List(discriminator)
    )
    CaseDef(pattern, Some(guard), rhs.asTerm)
}

However, I think the code it generates is not ideal:

scrutinee match {
  case d if d == 1 => "Dog"
  case d if d == 2 => "Cat"
}

The reason is that I suspect the compiler is not smart enough to optimize a match containing guards into a lookup table (or jump table).

I need to know whether it is possible to implement the buildCaseDef method so that the generated code is:

scrutinee match {
  case 1 => "Dog"
  case 2 => "Cat"
}
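Whether the guard-based match really loses the jump table can be checked rather than guessed. The standard annotation scala.annotation.switch makes the compiler warn when an annotated match is not compiled to a tableswitch or lookupswitch, and `javap -c` on the compiled class shows the actual bytecode. A minimal sketch for comparing the two shapes; the object name, the method names and the default cases are hypothetical additions, and the compiler may still decline to emit a switch for a match with very few cases:

```scala
import scala.annotation.switch

// Hypothetical object used only to compare the two match shapes.
object SwitchCheck:
  // Literal Int patterns, as in the desired generated code. With @switch the
  // compiler warns if it does not compile this match into a tableswitch or
  // lookupswitch.
  def typeNameOf(discriminator: Int): String =
    (discriminator: @switch) match
      case 1 => "Dog"
      case 2 => "Cat"
      case _ => "Unknown" // added default case, not in the question

  // Guard-based shape, as produced by the guard version of buildCaseDef.
  def typeNameOfGuarded(discriminator: Int): String =
    discriminator match
      case d if d == 1 => "Dog"
      case d if d == 2 => "Cat"
      case _ => "Unknown" // added default case, not in the question
```

Both methods return the same results; only the compiler warnings and the bytecode tell them apart.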

An answer saying "It is not possible as of Scala 3.7.0" is valid.


Solution

  • It's not clear how you'd like to implement reduceTerm transforming the tree

    TypeApply(
      Select(
        Ident(given_animalDc_type),
        discriminator
      ),
      List(
        TypeTree[TypeRef(ThisType(TypeRef(ThisType(TypeRef(NoPrefix,module class <empty>)),module class UserCode$)),class Dog)]
      )
    )
    

    into the tree Literal(IntConstant(1)), transforming the tree

    TypeApply(
      Select(
        Ident(given_animalDc_type),
        discriminator
      ),
      List(
        TypeTree[TypeRef(ThisType(TypeRef(ThisType(TypeRef(NoPrefix,module class <empty>)),module class UserCode$)),class Cat)]
      )
    )
    

    into the tree Literal(IntConstant(2)), without using Mirror once again.

    Try just adding a counting parameter to loop:

    def loop[RemainingVariants: Type](count: Int): List[CaseDef] = {
      Type.of[RemainingVariants] match {
        case '[headVariant *: tailVariants] =>
          val headDiscriminator: Term = discriminatorByVariant.appliedToType(TypeRepr.of[headVariant])
          val pattern: Term = Literal(IntConstant(count))
          val rhs: Expr[String] = '{ ${headDiscriminator.asExprOf[Int]}.toString }
          CaseDef(pattern, None, rhs.asTerm) :: loop[tailVariants](count + 1)
        
        case '[EmptyTuple] => Nil
      }
    }
    
    val cases = loop[Variants](1)
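    The counting version works here only because animalDc assigns 1 and 2 in the same order in which the variants appear in mirror.MirroredElemTypes; if a DiscriminationCriteria instance used different values, the literal pattern and the rhs would disagree. That order can be inspected through the sum mirror's ordinal, which gives a value's zero-based position among the children. A minimal sketch, with a hypothetical object OrdinalCheck holding copies of the question's types:

    ```scala
    import scala.deriving.Mirror

    // Hypothetical object holding copies of the question's types.
    object OrdinalCheck:
      sealed trait Animal
      case class Dog(dogField: Int) extends Animal
      case class Cat(catField: String) extends Animal

      // The sum mirror lists the children in declaration order (the order of
      // MirroredElemTypes), and `ordinal` returns a value's zero-based
      // position in that list.
      private val mirror: Mirror.SumOf[Animal] = summon[Mirror.SumOf[Animal]]

      def ordinalOf(animal: Animal): Int = mirror.ordinal(animal)
    ```

    Starting the count at 1 therefore yields ordinal + 1 for each variant, which happens to match animalDc.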
    

    By the way, you can try to derive the type class DiscriminationCriteria:

    import scala.compiletime.constValue
    import scala.compiletime.ops.int.S
    import scala.deriving.Mirror
    
    type IndexOf[Elem, Tup <: Tuple] <: Int = Tup match
      case Elem *: _ => 0
      case _    *: t => S[IndexOf[Elem, t]]
    
    object DiscriminationCriteria:
      given [S](using mirror: Mirror.SumOf[S]): DiscriminationCriteria[S] =
        new DiscriminationCriteria[S]:
          override transparent inline def discriminator[P <: S]: Int =
            constValue/*valueOf*/[IndexOf[P, mirror.MirroredElemTypes]]
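
    The IndexOf match type can also be checked on its own. A minimal sketch, assuming a hypothetical tuple of unrelated types in place of mirror.MirroredElemTypes:

    ```scala
    import scala.compiletime.constValue
    import scala.compiletime.ops.int.S

    // Hypothetical object; IndexOf is copied from the snippet above.
    object IndexOfCheck:
      type IndexOf[Elem, Tup <: Tuple] <: Int = Tup match
        case Elem *: _ => 0
        case _ *: t    => S[IndexOf[Elem, t]]

      // Hypothetical stand-in for mirror.MirroredElemTypes.
      type Variants = (Int, String, Boolean)

      // constValue turns each fully reduced literal type into a runtime Int.
      val indexOfInt: Int = constValue[IndexOf[Int, Variants]]
      val indexOfString: Int = constValue[IndexOf[String, Variants]]
      val indexOfBoolean: Int = constValue[IndexOf[Boolean, Variants]]
    ```

    The first case is skipped only when the compiler can prove the head type disjoint from Elem, which holds for distinct classes such as Dog and Cat, so the derived discriminators are zero-based indices.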
    

    Also, I'm not sure you need macros at all. You can try something like:

    transparent inline def typeNameCorrespondingTo[S]: PartiallyApplied[S] =
      new PartiallyApplied[S]
    
    class PartiallyApplied[S]:
      transparent inline def apply[N <: Int & Singleton](inline discriminator: N)(using
        discriminationCriteria: DiscriminationCriteria[S],
        mirror: Mirror.SumOf[S],
      ): String =
        discriminationCriteria.discriminator[Tuple.Elem[mirror.MirroredElemTypes, N] & S].toString
    

    In Scala 2 macros there is c.eval. Scala 3 macros are more restrictive: eval is absent, and this is intentional. Generally, an Expr[T] can't be transformed into a T, because Expr[T] and T exist in different contexts (at different times). (Of course, you can run the compiler manually and try to compile the Expr[T], or its source code, into a T.) There is staging.run, but it's forbidden inside macros (unless the compiler is patched):

    get annotations from class in scala 3 macros

    https://github.com/DmytroMitin/dotty-patched

    Try the following improved version of buildCaseDef:

    private def buildCaseDef(using quotes: Quotes)(discriminator: quotes.reflect.Term, rhs: Expr[String]): quotes.reflect.CaseDef = {
      import quotes.reflect.*
      val bindSymbol = Symbol.newBind(Symbol.spliceOwner, "d", Flags.EmptyFlags, TypeRepr.of[Int])
      val pattern = Bind(bindSymbol, Wildcard())
      val guard = Select.overloaded(
        Ref(bindSymbol),
        "==",
        Nil,
        List(discriminator)
      )
    
      val literalOpt = guard.underlying match {
        case Apply(Select(_, _), List(literal@Literal(IntConstant(_)))) => Some(literal)
        case _ => None
      }
    
      literalOpt.map { literal =>
        CaseDef(literal, None, rhs.asTerm)
      }.getOrElse(
        CaseDef(pattern, Some(guard), rhs.asTerm)
      )
    }