I'm looking for a way to find the values (or their classes) that are captured by a lambda in Scala 3, for serialization along the lines of Spark (I don't need Scala 2 support):
val a = "abc"
val f = () => a + "xyz"
serialize(f) // Should detect a / String as captured value
Doing this at runtime is fairly easy (by iterating over f.getClass.getDeclaredFields), but I would like to do it at compile time.
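A minimal sketch of that runtime approach, assuming the JVM backend, where a lambda is materialized as a synthetic class whose fields hold the captured values. The field layout is an implementation detail rather than a guarantee, and capturedFieldTypes and demo are names made up for this illustration:

```scala
// Hypothetical helper: lists the types of the fields declared on the
// runtime class of a lambda. On the JVM these fields typically hold the
// captured values, but that layout is not guaranteed by the language.
def capturedFieldTypes(f: AnyRef): List[Class[?]] =
  f.getClass.getDeclaredFields.toList.map(_.getType)

// Hypothetical demo: `a` is a local val, so the lambda has to capture the
// value itself rather than an enclosing object.
def demo(): List[Class[?]] =
  val a = "abc"
  val f = () => a + "xyz"
  capturedFieldTypes(f)
```

Calling demo() should report String among the captured types. If a were a member of an object or class instead of a local val, the lambda would likely capture the enclosing instance (or nothing at all), which is one reason the runtime view can differ from what the source suggests.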
I tried inspecting the type of the lambda in a macro, but it is reported as a plain scala.Function0 without any useful information.
I wonder if I can do some tree inspection, but I would really like to avoid that - I feel like I would have to copy compiler internals to catch all edge cases.
Try the following macro:
import scala.quoted.*

inline def serialize(x: Any): Unit = ${ serializeImpl('x) }

def serializeImpl(x: Expr[Any])(using Quotes): Expr[Unit] = {
  import quotes.reflect.*

  // A symbol followed by its chain of owners, innermost first
  def owners(s: Symbol): List[Symbol] =
    s :: List.unfold(s)(s1 => Option.when(s1.maybeOwner != Symbol.noSymbol)((s1.maybeOwner, s1.maybeOwner)))

  // Symbol of the val that holds the lambda (f in the usage below)
  val symbol = x.asTerm.underlying.symbol
  val rhs = symbol.tree match {
    case ValDef(_, _, Some(rhs)) => rhs
    case _ => report.errorAndAbort("serialize expects a reference to a val with a right-hand side")
  }

  val traverser = new TreeTraverser {
    override def traverseTree(tree: Tree)(owner: Symbol): Unit = {
      tree match {
        case Ident(name) =>
          val symbol1 = tree.symbol
          // pos is an Option, so avoid calling .get on symbols without a position
          val inCurrentFile = symbol1.pos.exists(_.sourceFile == SourceFile.current)
          println(s"identifier: $name, defined inside lambda: ${owners(symbol1).contains(symbol)}, defined in current file: $inCurrentFile")
        case _ =>
      }
      super.traverseTree(tree)(owner)
    }
  }
  traverser.traverseTree(rhs)(rhs.symbol)
  '{()}
}
Usage:
object App1 {
  val b = "bbb"
}

import App1.b

object App {
  val a = "abc"
  val f = () => { val x = "uvw"; a + b + x + "xyz" }
  serialize(f)
}
//scalac: identifier: a, defined inside lambda: false, defined in current file: true
//scalac: identifier: b, defined inside lambda: false, defined in current file: false
//scalac: identifier: x, defined inside lambda: true, defined in current file: true
//scalac: identifier: $anonfun, defined inside lambda: true, defined in current file: true