I'm trying to implement a typed UDAF that returns a complex type. Somehow Spark cannot infer the type of the result column and makes it binary, putting the serialized data there. Here's a minimal example that reproduces the problem:
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{SparkSession, Encoder, Encoders}
case class Data(key: Int)
class NoopAgg[I] extends Aggregator[I, Map[String, Int], Map[String, Int]] {
  override def zero: Map[String, Int] = Map.empty[String, Int]
  override def reduce(b: Map[String, Int], a: I): Map[String, Int] = b
  override def merge(b1: Map[String, Int], b2: Map[String, Int]): Map[String, Int] = b1
  override def finish(reduction: Map[String, Int]): Map[String, Int] = reduction
  override def bufferEncoder: Encoder[Map[String, Int]] = Encoders.kryo[Map[String, Int]]
  override def outputEncoder: Encoder[Map[String, Int]] = Encoders.kryo[Map[String, Int]]
}

object Question {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").getOrCreate()
    val sc = spark.sparkContext
    import spark.implicits._

    val ds = sc.parallelize((1 to 10).map(i => Data(i))).toDS()
    val noop = new NoopAgg[Data]().toColumn
    val result = ds.groupByKey(_.key).agg(noop.as("my_sum").as[Map[String, Int]])
    result.printSchema()
  }
}
It prints
root
|-- value: integer (nullable = false)
|-- my_sum: binary (nullable = true)
There is no inference here at all; instead, you get more or less what you asked for. Specifically, the mistake is here:
override def outputEncoder: Encoder[Map[String, Int]] = Encoders.kryo[Map[String, Int]]
Encoders.kryo means that you apply general-purpose serialization and get back a binary blob. The misleading part is .as[Map[String, Int]]: contrary to what one might expect, it is not statically type checked. To make it even worse, it is not even proactively validated by the query planner, and the runtime exception is thrown only when the result is evaluated:
result.first
org.apache.spark.sql.AnalysisException: cannot resolve '`my_sum`' due to data type mismatch: cannot cast binary to map<string,int>;
at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$3.applyOrElse(CheckAnalysis.scala:115)
...
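You can see where the binary column comes from by inspecting the Kryo encoder itself: its schema is a single serialized binary field regardless of the type parameter. A quick illustrative check (not part of the fix):

import org.apache.spark.sql.Encoders

// A Kryo-based encoder describes its data as one opaque binary column
Encoders.kryo[Map[String, Int]].schema.printTreeString()
// prints something like:
// root
//  |-- value: binary (nullable = true)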
You should provide a specific Encoder instead, either explicitly:
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
override def outputEncoder: Encoder[Map[String, Int]] = ExpressionEncoder()
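Put together, a sketch of the explicit variant (everything except outputEncoder is unchanged from the question):

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

class NoopAgg[I] extends Aggregator[I, Map[String, Int], Map[String, Int]] {
  override def zero: Map[String, Int] = Map.empty[String, Int]
  override def reduce(b: Map[String, Int], a: I): Map[String, Int] = b
  override def merge(b1: Map[String, Int], b2: Map[String, Int]): Map[String, Int] = b1
  override def finish(reduction: Map[String, Int]): Map[String, Int] = reduction
  // Kryo is still fine for the buffer; it never surfaces in the result schema
  override def bufferEncoder: Encoder[Map[String, Int]] = Encoders.kryo[Map[String, Int]]
  // ExpressionEncoder gives Spark a real map<string,int> schema for the output
  override def outputEncoder: Encoder[Map[String, Int]] = ExpressionEncoder()
}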
or implicitly:

class NoopAgg[I](implicit val enc: Encoder[Map[String, Int]]) extends Aggregator[I, Map[String, Int], Map[String, Int]] {
  ...
  override def outputEncoder: Encoder[Map[String, Int]] = enc
}
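With the implicit variant, the encoder just has to be in scope when the aggregator is constructed. A sketch reusing the spark session and Data class from the question, and assuming a Spark version whose spark.implicits._ provides an implicit Encoder for Map[String, Int] (otherwise pass one explicitly, e.g. new NoopAgg[Data]()(ExpressionEncoder())):

import spark.implicits._

// The encoder is captured through the implicit constructor parameter
val noop = new NoopAgg[Data]().toColumn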
As a side effect, it will make .as[Map[String, Int]] obsolete, as the return type of the Aggregator is already known.
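In other words, reusing ds from the question and the fixed noop column from above, you can aggregate without the cast and still get a proper map column (the schema in the comment is roughly what printSchema reports):

// name on a TypedColumn keeps it typed, so it can be passed to agg directly
val result = ds.groupByKey(_.key).agg(noop.name("my_sum"))

result.printSchema()
// root
//  |-- value: integer (nullable = false)
//  |-- my_sum: map (nullable = true)
//  |    |-- key: string
//  |    |-- value: integer (valueContainsNull = false)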