
Is there a simple Scala 3 example of how to use `quoted.Type` as a replacement for `TypeTag`?


Martin Odersky said

Scala 3 has the quoted package, with quoted.Expr as a representation of expressions and quoted.Type as a representation of types. quoted.Type essentially replaces TypeTag. It does not have the same API but has similar functionality. It should be easier to use since it integrates well with quoted terms and pattern matching.

I knew how to use TypeTag in Scala 2:

import scala.reflect.runtime.universe._

def myFun[T](foo: T)(implicit tag: TypeTag[T]) =
  tag.tpe // and now I can do whatever I want with tag, e.g. inspect its Type
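For comparison, and not as a full `TypeTag` replacement: when `T` only has to be inspected at compile time, Scala 3 can do it without a macro or a runtime tag by using `inline` methods together with `scala.compiletime.erasedValue`. In this sketch, `describe` and `myFun` are hypothetical names. Because both are `inline`, the `inline match` is resolved against the concrete `T` at each call site. Nothing is kept at runtime, so this does not cover storing type information beside a value.

```scala
import scala.compiletime.erasedValue

// Hypothetical helper: being `inline`, its body is expanded at every call site,
// so the `inline match` is reduced using the caller's concrete T.
inline def describe[T]: String =
  inline erasedValue[T] match
    case _: Int     => "an Int"
    case _: String  => "a String"
    case _: List[?] => "some List"
    case _          => "something else"

// Hypothetical counterpart of myFun: it must also be inline so that T is still
// concrete when describe[T] is expanded.
inline def myFun[T](foo: T): String = describe[T]

@main def inlineDemo(): Unit =
  println(myFun(1))         // an Int
  println(myFun("a"))       // a String
  println(myFun(List(1.0))) // some List
```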

but I have no idea how to do something similar with Type. I have a case where I need to keep type information around, which is exactly the use case for TypeTag, but I can't find any examples of how to do this in Scala 3. (Well, people point to izumi-reflect and similar libraries, so I should say that I can't find any examples that are accessible, at least to me.)

Can someone tell me (a) what's the type of the class that I should be using to store type info so that I can, for instance, use that info to correctly cast a value, and (b) how to use quoted.Type to get such a thing?

What I mean by (a) is that I have, say, an Iterator[Stuff[?]] with the equivalent of case class Stuff[T](value: T, tag: TypeTag[T]). When I get the next element from the iterator, I need to be able to cast it to a more specific type than Stuff[?], and that's possible because the tag somehow reifies the type and avoids type erasure, in the same way that (explicitly) saving a Class<T> beside a value would in Java.
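For the runtime half of (a), a minimal sketch is possible with the standard library's `scala.reflect.ClassTag`, which still exists in Scala 3. This is a swapped-in, weaker mechanism, not `quoted.Type`: a `ClassTag` records only the erased runtime class, exactly like a saved `Class<T>` in Java, so `List[Int]` and `List[String]` get the same tag. `Stuff.of` and `narrow` are hypothetical helpers.

```scala
import scala.reflect.ClassTag

// Hypothetical stand-in for the question's Stuff[T]. The ClassTag only records
// T's erased runtime class, so type arguments of T are not distinguished.
final case class Stuff[T](value: T, tag: ClassTag[T])

object Stuff:
  // Captures the tag at the call site, where T is statically known.
  def of[T](value: T)(using tag: ClassTag[T]): Stuff[T] = Stuff(value, tag)

// Hypothetical helper: recovers a Stuff[U] from a Stuff[?] when the stored tag
// equals U's tag (ClassTag equality effectively compares runtime classes).
def narrow[U](s: Stuff[?])(using u: ClassTag[U]): Option[Stuff[U]] =
  if s.tag == u then Some(s.asInstanceOf[Stuff[U]]) else None

@main def stuffDemo(): Unit =
  val items: Iterator[Stuff[?]] = Iterator(Stuff.of(42), Stuff.of("hello"))
  for s <- items do
    narrow[Int](s).foreach(i => println(i.value + 1))            // prints 43
    narrow[String](s).foreach(t => println(t.value.toUpperCase)) // prints HELLO
```

The cast in `narrow` is only as safe as the tag comparison, which is why a richer representation of the type (such as the one in the answer) is needed when generic type arguments matter.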

To really boil this down, what I want is to be able to write something like

def foo: Foo[T] = someFun(arg1: X, arg2: T)

and have the type information for T pushed from the left-hand side of the definition to the right-hand side, so that what T was when someFun was called is accessible inside the body of someFun. I think the answer is "macros can do it", but heck if I can figure out how.
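In Scala 3 this "pushing" is what a `using` (implicit) parameter does: the compiler infers `T` at the call site, possibly from the expected type on the left, and then fills in the given for that `T`. The sketch below uses the standard library's `ClassTag` as a stand-in for a type tag; `Foo` and `someFun` are hypothetical. A macro-backed tag such as the answer's `TypeTag` (an `inline given`) should be resolved at the call site in the same way.

```scala
import scala.reflect.ClassTag

// Hypothetical result type standing in for the question's Foo[T].
final case class Foo[T](arg1: String, runtimeClass: Class[?])

// Hypothetical someFun: the ClassTag[T] is supplied by the compiler at each call
// site, after T has been inferred (here from the declared result type).
def someFun[T](arg1: String)(using tag: ClassTag[T]): Foo[T] =
  Foo(arg1, tag.runtimeClass)

@main def pushDemo(): Unit =
  val foo: Foo[Int] = someFun("a")    // T = Int, taken from the declared type
  val bar: Foo[String] = someFun("b") // T = String
  println(foo.runtimeClass)           // int
  println(bar.runtimeClass)           // class java.lang.String
```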


Solution

  • Try the following approach. It's too hard to expose the actual TypeRepr outside macros (because of path-dependent types and the implicit Quotes), so I'm creating my own runtime Type hierarchy that mirrors the one from quotes.reflect.*.

    case class TypeTag[T](tpe: my.Type)
    object TypeTag:
      inline given [T]: TypeTag[T] = TypeTag(getType[T])
    
    def typeTag[T: TypeTag]: TypeTag[T] = summon[TypeTag[T]]
    def typeOf[T: TypeTag]: my.Type = summon[TypeTag[T]].tpe
    
    object my:
      sealed trait Type
      case class ConstantType(constant: Constant) extends Type
      sealed trait NamedType extends Type:
        def qualifier: Type
        def name: String
      case class TermRef(qualifier: Type, name: String) extends NamedType
      case class TypeRef(qualifier: Type, name: String) extends NamedType //
      case class SuperType(thisTpe: Type, superTpe: Type) extends Type
      case class Refinement(parent: Type, name: String, info: Type) extends Type
      case class AppliedType(tycon: Type, args: List[Type]) extends Type
      case class AnnotatedType(underlying: Type, annot: Term) extends Type
      sealed trait AndOrType extends Type:
        def left: Type
        def right: Type
      case class AndType(left: Type, right: Type) extends AndOrType
      case class OrType(left: Type, right: Type) extends AndOrType
      case class MatchType(bound: Type, scrutinee: Type, cases: List[Type]) extends Type
      case class ByNameType(underlying: Type) extends Type
      case class ParamRef(binder: Type, paramNum: Int) extends Type //
      case class ThisType(tref: Type) extends Type //
      case class RecursiveThis(binder: RecursiveType) extends Type //
      case class RecursiveType(underlying: Type, recThis: RecursiveThis) extends Type //
      sealed trait LambdaType extends Type:
        def paramNames: List[String]
        def paramTypes: List[Type]
        def resType: Type
      sealed trait MethodOrPoly extends LambdaType
      case class MethodType(paramNames: List[String], paramTypes: List[Type], resType: Type) extends MethodOrPoly
      case class PolyType(paramNames: List[String], paramTypes: List[TypeBounds], resType: Type) extends MethodOrPoly
      case class TypeLambda(paramNames: List[String], paramTypes: List[TypeBounds], resType: Type) extends LambdaType
      case class MatchCase(pattern: Type, rhs: Type) extends Type
      case class TypeBounds(low: Type, hi: Type) extends Type
      case object NoPrefix extends Type
    
      sealed trait Term
      case class New(tpe: Type/*tpt: TypeTree*/) extends Term
    
      sealed trait Constant
      case class BooleanConstant(b: Boolean) extends Constant
      case class ByteConstant(b: Byte) extends Constant
      case class ShortConstant(s: Short) extends Constant
      case class IntConstant(i: Int) extends Constant
      case class LongConstant(l: Long) extends Constant
      case class FloatConstant(f: Float) extends Constant
      case class DoubleConstant(d: Double) extends Constant
      case class CharConstant(c: Char) extends Constant
      case class StringConstant(s: String) extends Constant
      case object UnitConstant extends Constant
      case object NullConstant extends Constant
      case class ClassOfConstant(tpe: Type) extends Constant
    
    import scala.quoted.*
    
    inline def getType[T]: my.Type = ${getTypeImpl[T]}
    
    def getTypeImpl[T: Type](using Quotes): Expr[my.Type] =
      import quotes.reflect.*
    
      def mkConstant(constant: Constant): Expr[my.Constant] = constant match
        case BooleanConstant(b) => '{my.BooleanConstant(${ Expr(b) })}
        case ByteConstant(b) => '{my.ByteConstant(${ Expr(b) })}
        case ShortConstant(s) => '{my.ShortConstant(${ Expr(s) })}
        case IntConstant(i) => '{my.IntConstant(${ Expr(i) })}
        case LongConstant(l) => '{my.LongConstant(${ Expr(l) })}
        case FloatConstant(f) => '{my.FloatConstant(${ Expr(f) })}
        case DoubleConstant(d) => '{my.DoubleConstant(${ Expr(d) })}
        case CharConstant(c) => '{my.CharConstant(${ Expr(c) })}
        case StringConstant(s) => '{my.StringConstant(${ Expr(s) })}
        case UnitConstant() => '{my.UnitConstant}
        case NullConstant() => '{my.NullConstant}
        case ClassOfConstant(tpe) => '{my.ClassOfConstant(${ mkType(tpe) })}
    
      def mkType(tpe: TypeRepr): Expr[my.Type] = tpe match
        case ConstantType(constant) => '{ my.ConstantType(${mkConstant(constant)}) }
        case TermRef(qualifier, name) => '{my.TermRef(${mkType(qualifier)}, ${Expr(name)})}
        case TypeRef(qualifier, name) => '{my.TypeRef(${mkType(qualifier)}, ${Expr(name)})}
        case SuperType(thisTpe, superTpe) => '{my.SuperType(${mkType(thisTpe)}, ${mkType(superTpe)})}
        case Refinement(parent, name, info) => '{my.Refinement(${mkType(parent)}, ${Expr(name)}, ${mkType(info)})}
        case AppliedType(tycon, args) => '{my.AppliedType(${mkType(tycon)}, ${Expr.ofList(args.map(mkType))})}
        case AnnotatedType(underlying, annot) => '{my.AnnotatedType(${mkType(underlying)}, ${mkTerm(annot)})}
        case AndType(left, right) => '{my.AndType(${mkType(left)}, ${mkType(right)})}
        case OrType(left, right) => '{my.OrType(${mkType(left)}, ${mkType(right)})}
        case MatchType(bound, scrutinee, cases) => '{my.MatchType(${mkType(bound)}, ${mkType(scrutinee)}, ${Expr.ofList(cases.map(mkType))})}
        case ByNameType(underlying) => '{my.ByNameType(${mkType(underlying)})}
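        // Caution: ParamRef.binder is the enclosing lambda type, whose result type contains
        // this ParamRef again, so recursing into the binder does not terminate.
        case ParamRef(binder, paramNum) => '{my.ParamRef(${mkType(binder)}, ${Expr(paramNum)})}
        case ThisType(tref) => '{my.ThisType(${mkType(tref)})}
        // Caution: RecursiveThis and RecursiveType refer to each other, so this branch
        // (via mkRecursiveType/mkRecursiveThis) does not terminate either.
        case RecursiveThis(binder) => '{my.RecursiveThis(${mkRecursiveType(binder)})}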
        case MethodType(paramNames, paramTypes, resType) => '{my.MethodType(${Expr(paramNames)}, ${Expr.ofList(paramTypes.map(mkType))}, ${mkType(resType)})}
        case PolyType(paramNames, paramTypes, resType) => '{my.PolyType(${Expr(paramNames)}, ${Expr.ofList(paramTypes.map(mkTypeBounds))}, ${mkType(resType)})}
        case TypeLambda(paramNames, paramTypes, resType) => '{my.TypeLambda(${Expr(paramNames)}, ${Expr.ofList(paramTypes.map(mkTypeBounds))}, ${mkType(resType)})}
        case MatchCase(pattern, rhs) => '{my.MatchCase(${mkType(pattern)}, ${mkType(rhs)})}
        case TypeBounds(low, hi) => '{my.TypeBounds(${mkType(low)}, ${mkType(hi)})}
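        case NoPrefix() => '{my.NoPrefix}
        // Types not mirrored above (e.g. RecursiveType) fail at compile time with a readable message
        case other => report.errorAndAbort(s"getType: unsupported type ${other.show}")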
    
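      def mkTerm(term: Term): Expr[my.Term] = term match
        case New(tpt) => '{my.New(${mkType(tpt.tpe)})}
        // Other annotation trees (e.g. constructor calls with arguments) are not mirrored in this sketch
        case other => report.errorAndAbort(s"getType: unsupported annotation term ${other.show}")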
    
      def mkRecursiveThis(recThis: RecursiveThis): Expr[my.RecursiveThis] = recThis match
        case RecursiveThis(binder) => '{my.RecursiveThis(${mkRecursiveType(binder)})}
      def mkRecursiveType(recTpe: RecursiveType): Expr[my.RecursiveType] =
        '{my.RecursiveType(${mkType(recTpe.underlying)}, ${mkRecursiveThis(recTpe.recThis)})}
      def mkTypeBounds(typeBounds: TypeBounds): Expr[my.TypeBounds] = typeBounds match
        case TypeBounds(lo, hi) => '{my.TypeBounds(${mkType(lo)}, ${mkType(hi)})}
    
      mkType(TypeRepr.of[T])
    

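    Usage (this must live in a different file from the macro definition, because Scala 3 does not allow calling a macro from the source file that defines it):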

    def myFun[T: TypeTag](foo: T) = println(typeOf[T])
    
    myFun(1)
    //TypeRef(ThisType(TypeRef(NoPrefix,scala)),Int)
    myFun("a")
    //TypeRef(ThisType(TypeRef(NoPrefix,lang)),String)
    myFun((_: Int).toString)
    //AppliedType(TypeRef(ThisType(TypeRef(NoPrefix,scala)),Function1),List(TypeRef(TermRef(ThisType(TypeRef(NoPrefix,<root>)),scala),Int), TypeRef(ThisType(TypeRef(NoPrefix,lang)),String)))
    myFun(Map("a" -> 1, "b" -> 2))
    //AppliedType(TypeRef(ThisType(TypeRef(NoPrefix,immutable)),Map),List(TypeRef(ThisType(TypeRef(NoPrefix,lang)),String), TypeRef(ThisType(TypeRef(NoPrefix,scala)),Int)))
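Because the answer's `TypeTag` only wraps plain case classes, a `my.Type` value can be inspected at runtime with ordinary pattern matching, compared with `==`, or stored next to a value as in the question's `Stuff[T]`. So that it compiles on its own, this sketch re-declares just the five `my` cases that appear in the outputs above. It also adds a hypothetical `render` helper that turns the printed `Map` value back into Scala-like syntax.

```scala
// Minimal, hypothetical re-declaration of the subset of the answer's `my`
// hierarchy that appears in the printed outputs.
object my:
  sealed trait Type
  final case class TypeRef(qualifier: Type, name: String) extends Type
  final case class TermRef(qualifier: Type, name: String) extends Type
  final case class ThisType(tref: Type) extends Type
  final case class AppliedType(tycon: Type, args: List[Type]) extends Type
  case object NoPrefix extends Type

import my.*

// Hypothetical helper: keeps only simple names and type arguments.
def render(t: Type): String = t match
  case TypeRef(_, name)         => name
  case TermRef(_, name)         => name
  case ThisType(tref)           => render(tref)
  case AppliedType(tycon, args) => s"${render(tycon)}[${args.map(render).mkString(", ")}]"
  case NoPrefix                 => ""

@main def renderDemo(): Unit =
  // The value printed above for myFun(Map("a" -> 1, "b" -> 2))
  val mapType = AppliedType(
    TypeRef(ThisType(TypeRef(NoPrefix, "immutable")), "Map"),
    List(
      TypeRef(ThisType(TypeRef(NoPrefix, "lang")), "String"),
      TypeRef(ThisType(TypeRef(NoPrefix, "scala")), "Int")))
  println(render(mapType)) // Map[String, Int]
```

Structural `==` on these values is a natural basis for the question's cast check. However, two spellings of the same type (for example, one reached through a type alias) may produce different trees unless the macro normalizes them first, e.g. with `TypeRepr`'s `dealias`.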