The context is registering a UserDefinedFunction (UDF) in Spark, where the UDF is an anonymous function obtained via reflection. Since the function's signature is only determined at runtime, I was wondering whether this is possible.
Say the function impl() returns an anonymous function:
trait Base {}

class A extends Base {
  def impl(): Function1[Int, String] = new Function1[Int, String] {
    def apply(x: Int): String = "ab" + x.toString
  }
}
val classes = reflections.getSubTypesOf(classOf[Base]).toSet[Class[_ <: Base]].toList
and I obtain the anonymous function in another place:
val clazz = classes(0)
val instance = clazz.newInstance()
val impl = clazz.getDeclaredMethod("impl").invoke(instance)
Now impl holds the anonymous function, but I do not know its signature. Can it be cast back to a correctly typed function instance?
impl.asInstanceOf[Function1[Int, String]] // How to determine the function signature of the anonymous function, in this case Function1[Int, String]?
Since Scala does not support generic functions, I first considered getting the runtime type of the function:
import scala.reflect.runtime.universe.{TypeTag, typeTag}
def getTypeTag[T: TypeTag](obj: T) = typeTag[T]
val typeList = getTypeTag(impl).tpe.typeArgs
It returns List(Int, String), but I cannot work out how to reconstruct the correct function type from this via reflection.
Update: if the classes are defined as follows:
trait Base {}

class A extends Base {
  def impl(x: Int): String = {
    "ab" + x.toString
  }
}
where impl is the method itself and we do not know its signature, can impl still be registered?
Normally you register a UDF as follows
import org.apache.spark.sql.SparkSession

object App {
  val spark = SparkSession.builder
    .master("local")
    .appName("Spark app")
    .getOrCreate()

  def impl(): Int => String = x => "ab" + x.toString

  spark.udf.register("foo", impl())

  def main(args: Array[String]): Unit = {
    spark.sql("""SELECT foo(10)""").show()
    //+-------+
    //|foo(10)|
    //+-------+
    //|   ab10|
    //+-------+
  }
}
The signature of register is
def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction
aka
def register[RT, A1](name: String, func: Function1[A1, RT])(implicit
  ttag: TypeTag[RT],
  ttag1: TypeTag[A1]
): UserDefinedFunction
What TypeTag normally does is persist type information from compile time to runtime.
So in order to call register you either have to know the types at compile time or know how to construct the type tags at runtime.
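For illustration, here is a minimal sketch of my own (not from the original answer) showing what a TypeTag captures when the type is statically known:

import scala.reflect.runtime.universe._

// The full type, including its type arguments, is carried over to runtime.
val tt: TypeTag[Int => String] = typeTag[Int => String]
println(tt.tpe)          // Int => String
println(tt.tpe.typeArgs) // List(Int, String)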
If you don't have access to how impl() is constructed at runtime and you don't have (at least at runtime) any information about the types/type tags, then unfortunately this type information is irreversibly lost because of type erasure (Function1[Int, String] is just Function1[_, _] at runtime):
def impl(): Any = (x: Int) => "ab" + x.toString
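To see the erasure in action, here is a minimal sketch of my own (reusing the question's getTypeTag helper):

import scala.reflect.runtime.universe._

def getTypeTag[T: TypeTag](obj: T) = typeTag[T]

def impl(): Any = (x: Int) => "ab" + x.toString

// The static type of impl() is Any, so the tag captures Any and the
// Function1 type arguments are no longer visible.
println(getTypeTag(impl()).tpe)          // Any
println(getTypeTag(impl()).tpe.typeArgs) // List()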
But it's possible that you do have access to how impl() is constructed at runtime and you do know (at least at runtime) the information about the types/type tags. So I assume that you don't have the types Int, String statically and you can't call typeTag[Int], typeTag[String] (as I do below), but you somehow have runtime objects of Type/TypeTag:
import org.apache.spark.sql.catalyst.ScalaReflection.universe._
def impl(): Any = (x: Int) => "ab" + x.toString
val ttag1 = typeTag[Int] // actual definition is probably different
val ttag = typeTag[String] // actual definition is probably different
In that case you can call register, resolving the implicits explicitly:
spark.udf.register("foo", impl().asInstanceOf[Function1[_,_]])(ttag.asInstanceOf[TypeTag[_]], ttag1.asInstanceOf[TypeTag[_]])
Well, this doesn't compile because of existential types, but you can trick the compiler:
type A
type B
spark.udf.register("foo", impl().asInstanceOf[A => B])(ttag.asInstanceOf[TypeTag[B]], ttag1.asInstanceOf[TypeTag[A]])
https://gist.github.com/DmytroMitin/0b3660d646f74fb109665bad41b3ae9f
Alternatively you can use runtime compilation (creating a new compile time inside the runtime)
import org.apache.spark.sql.catalyst.ScalaReflection
import ScalaReflection.universe._
import scala.tools.reflect.ToolBox // libraryDependencies += scalaOrganization.value % "scala-compiler" % scalaVersion.value
val rm = ScalaReflection.mirror
val tb = rm.mkToolBox()
tb.eval(q"""App.spark.udf.register("foo", App.impl().asInstanceOf[$ttag1 => $ttag])""")
https://gist.github.com/DmytroMitin/5b5dd4d7db0d0eebb51dd8c16735e0fb
You should provide some code showing how you construct impl(), and we'll see whether it's possible to restore the types.
Spark registered a Scala object all of the methods as a UDF
scala cast object based on reflection symbol
Update. After you get val impl = clazz.getDeclaredMethod("impl").invoke(instance) it's too late to restore the function types (you can check that typeList is empty). The place where the function type (or type tag) should be captured is somewhere not too far from class A, maybe inside A or outside A, but while Int and String are not lost yet. What TypeTag can do is persist type information from compile time to runtime; it can't restore type information at runtime once it's lost.
import org.apache.spark.sql.catalyst.ScalaReflection
import ScalaReflection.universe._
import org.apache.spark.sql.SparkSession
import org.reflections.Reflections
import scala.jdk.CollectionConverters._
import scala.reflect.api

object App {
  def getType[T: TypeTag](obj: T) = typeOf[T]

  trait Base

  class A extends Base {
    def impl(): Int => String = x => "ab" + x.toString
    // NotSerializableException
    //def impl(): Function1[Int, String] = new Function1[Int, String] {
    //  def apply(x: Int): String = "ab" + x.toString
    //}
    val tpe = getType(impl())
  }

  val reflections = new Reflections()
  val classes: List[Class[_ <: Base]] = reflections.getSubTypesOf(classOf[Base]).asScala.toList
  val clazz = classes(0)
  val instance = clazz.newInstance()
  val impl = clazz.getDeclaredMethod("impl").invoke(instance)
  val functionType = clazz.getDeclaredMethod("tpe").invoke(instance).asInstanceOf[Type]
  val List(argType, returnType) = functionType.typeArgs

  val spark = SparkSession.builder()
    .master("local")
    .appName("Spark app")
    .getOrCreate()

  val rm = ScalaReflection.mirror

  // (*)
  def typeToTypeTag[T](tpe: Type): TypeTag[T] =
    TypeTag(rm, new api.TypeCreator {
      def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
        tpe.asInstanceOf[U#Type]
    })

  // type X
  // type Y
  // spark.udf.register("foo", impl.asInstanceOf[X => Y])(
  //   typeToTypeTag[Y](returnType),
  //   typeToTypeTag[X](argType)
  // )

  impl match {
    case impl: Function1[x, y] => spark.udf.register("foo", impl)(
      typeToTypeTag[y](returnType),
      typeToTypeTag[x](argType)
    )
  }

  def main(args: Array[String]): Unit = {
    spark.sql("""SELECT foo(10)""").show()
  }
}
https://gist.github.com/DmytroMitin/2ebfae922f8a467d01b6ef18c8b8e5ad
(*) Get a TypeTag from a Type?
Now spark.sql("""SELECT foo(10)""").show() throws java.io.NotSerializableException, but I guess it's not related to reflection.
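If the exception is caused by the registered closure capturing the enclosing A instance (an assumption on my part; the failing run isn't diagnosed here), a common workaround is making the class serializable:

// A hedged sketch, not verified against the gist above: marking A as Serializable
// so that Spark can serialize whatever the UDF closure captures.
class A extends Base with Serializable {
  def impl(): Int => String = x => "ab" + x.toString
  val tpe = getType(impl())
}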
Alternatively you can use runtime compilation (instead of manual resolution of implicits and construction of type tags from types)
import scala.tools.reflect.ToolBox
val rm = ScalaReflection.mirror
val tb = rm.mkToolBox()
tb.eval(q"""App.spark.udf.register("foo", App.impl.asInstanceOf[$functionType])""")
https://gist.github.com/DmytroMitin/ba469faeca2230890845e1532b36e2a1
One more option is to request the return type of the method impl() as soon as we get class A (outside A):
class A extends Base {
  def impl(): Int => String = x => "ab" + x.toString
}
// ...
val functionType = rm.classSymbol(clazz).typeSignature.decl(TermName("impl")).asMethod.returnType
val List(argType, returnType) = functionType.typeArgs
https://gist.github.com/DmytroMitin/3bd2c19d158f8241a80952c397ee5e09
Update 2. If the methods are defined as follows:
class A extends Base {
  def impl(x: Int): String = {
    "ab" + x.toString
  }
}
then runtime compilation normally should be
val rm = ScalaReflection.mirror
val classSymbol = rm.classSymbol(clazz)
val methodSymbol = classSymbol.typeSignature.decl(TermName("impl")).asMethod
val tb = rm.mkToolBox()
tb.eval(q"""App.spark.udf.register("foo", (new $classSymbol).$methodSymbol(_))""")
or
tb.eval(q"""App.spark.udf.register("foo", (new $classSymbol).impl(_))""")
but now with Spark it produces ClassCastException: cannot assign instance of java.lang.invoke.SerializedLambda to field org.apache.spark.sql.catalyst.expressions.ScalaUDF.f of type scala.Function1 in instance of org.apache.spark.sql.catalyst.expressions.ScalaUDF
similarly to Spark registered a Scala object all of the methods as a UDF
https://gist.github.com/DmytroMitin/b0f110f4cf15e2dfd4add70f7124a7b6
But ordinary Scala runtime reflection seems to work
val rm = ScalaReflection.mirror
val classSymbol = rm.classSymbol(clazz)
val methodSymbol = classSymbol.typeSignature.decl(TermName("impl")).asMethod
val returnType = methodSymbol.returnType
val argType = methodSymbol.paramLists.head.head.typeSignature
val constructorSymbol = classSymbol.typeSignature.decl(termNames.CONSTRUCTOR).asMethod
val instance = rm.reflectClass(classSymbol).reflectConstructor(constructorSymbol)()
val impl: Any => Any = rm.reflect(instance).reflectMethod(methodSymbol)(_)

def typeToTypeTag[T](tpe: Type): TypeTag[T] =
  TypeTag(rm, new api.TypeCreator {
    def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
      tpe.asInstanceOf[U#Type]
  })

impl match {
  case impl: Function1[x, y] => spark.udf.register("foo", impl)(
    typeToTypeTag[y](returnType),
    typeToTypeTag[x](argType)
  )
}
https://gist.github.com/DmytroMitin/763751096fe9cdb2e0d18ae4b9290a54
Update 3. One more approach is to use compile-time reflection (macros) rather than runtime reflection if you have enough information at compile time (e.g. if all the classes are known at compile time)
import scala.collection.mutable
import scala.language.experimental.macros
import scala.reflect.macros.blackbox

object Macros {
  def registerMethod[A](): Unit = macro registerMethodImpl[A]

  def registerMethodImpl[A: c.WeakTypeTag](c: blackbox.Context)(): c.Tree = {
    import c.universe._
    val A = weakTypeOf[A]
    var children = mutable.Seq[Type]()
    val traverser = new Traverser {
      override def traverse(tree: Tree): Unit = {
        tree match {
          case _: ClassDef =>
            val tpe = tree.symbol.asClass.toType
            if (tpe <:< A && !(tpe =:= A)) children :+= tpe
          case _ =>
        }
        super.traverse(tree)
      }
    }
    c.enclosingRun.units.foreach(unit => traverser.traverse(unit.body))
    val calls = children.map(tpe =>
      q"""spark.udf.register("foo", (new $tpe).impl(_))"""
    )
    q"..$calls"
  }
}

// in a different subproject
import org.apache.spark.sql.SparkSession

object App {
  trait Base

  class A extends Base {
    def impl(x: Int): String = "ab" + x.toString
  }

  val spark = SparkSession.builder()
    .master("local")
    .appName("Spark app")
    .getOrCreate()

  Macros.registerMethod[Base]()

  def main(args: Array[String]): Unit = {
    spark.sql("""SELECT foo(10)""").show()
  }
}
https://gist.github.com/DmytroMitin/6623f1f900330f8341f209e1347a0007
Shapeless - How to derive LabelledGeneric for Coproduct (KnownSubclasses)
Update 4. If we replace val clazz = classes.head with classes.foreach(clazz => ..., then the issues with NotSerializableException can be fixed with inlining:
import scala.language.experimental.macros
import scala.reflect.macros.blackbox

object Macros {
  def registerMethod(clazz: Class[_]): Unit = macro registerMethodImpl

  def registerMethodImpl(c: blackbox.Context)(clazz: c.Tree): c.Tree = {
    import c.universe._
    val ScalaReflection = q"_root_.org.apache.spark.sql.catalyst.ScalaReflection"
    val rm = q"$ScalaReflection.mirror"
    val ru = q"$ScalaReflection.universe"
    val classSymbol = q"$rm.classSymbol($clazz)"
    val methodSymbol = q"""$classSymbol.typeSignature.decl($ru.TermName("impl")).asMethod"""
    val returnType = q"$methodSymbol.returnType"
    val argType = q"$methodSymbol.paramLists.head.head.typeSignature"
    val constructorSymbol = q"$classSymbol.typeSignature.decl($ru.termNames.CONSTRUCTOR).asMethod"
    val instance = q"$rm.reflectClass($classSymbol).reflectConstructor($constructorSymbol).apply()"
    val impl1 = q"(x: Any) => $rm.reflect($instance).reflectMethod($methodSymbol).apply(x)"
    val api = q"_root_.scala.reflect.api"

    def typeToTypeTag(T: Tree, tpe: Tree): Tree =
      q"""
        $ru.TypeTag[$T]($rm, new $api.TypeCreator {
          override def apply[U <: $api.Universe with _root_.scala.Singleton](m: $api.Mirror[U]) =
            $tpe.asInstanceOf[U#Type]
        })
      """

    val impl2 = TermName(c.freshName("impl2"))
    val x = TypeName(c.freshName("x"))
    val y = TypeName(c.freshName("y"))

    q"""
      $impl1 match {
        case $impl2: _root_.scala.Function1[$x, $y] => spark.udf.register("foo", $impl2)(
          ${typeToTypeTag(tq"$y", returnType)},
          ${typeToTypeTag(tq"$x", argType)}
        )
      }
    """
  }
}

// in a different subproject
import org.apache.spark.sql.SparkSession
import org.reflections.Reflections
import scala.jdk.CollectionConverters._

trait Base

class A extends Base /*with Serializable*/ {
  def impl(x: Int): String = "ab" + x.toString
}

object App {
  val spark: SparkSession = SparkSession.builder()
    .master("local")
    .appName("Spark app")
    .getOrCreate()

  val reflections = new Reflections()
  val classes: List[Class[_ <: Base]] = reflections.getSubTypesOf(classOf[Base]).asScala.toList

  classes.foreach(clazz =>
    Macros.registerMethod(clazz)
  )

  def main(args: Array[String]): Unit = {
    spark.sql("""SELECT foo(10)""").show()
  }
}
https://gist.github.com/DmytroMitin/c926158a9ff94a6539097c603bbedf6a