scala apache-spark apache-spark-sql apache-spark-dataset apache-spark-encoders

Column type inferred as binary with typed UDAF


I'm trying to implement a typed UDAF that returns a complex type. Spark does not seem to infer the type of the result column: it reports the column as binary and stores the serialized data in it. Here is a minimal example that reproduces the problem:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{SparkSession, Encoder, Encoders}

case class Data(key: Int)

class NoopAgg[I] extends Aggregator[I, Map[String, Int], Map[String, Int]] {
    override def zero: Map[String, Int] = Map.empty[String, Int]

    override def reduce(b: Map[String, Int], a: I): Map[String, Int] = b

    override def merge(b1: Map[String, Int], b2: Map[String, Int]): Map[String, Int] = b1

    override def finish(reduction: Map[String, Int]): Map[String, Int] = reduction

    override def bufferEncoder: Encoder[Map[String, Int]] = Encoders.kryo[Map[String, Int]]

    override def outputEncoder: Encoder[Map[String, Int]] = Encoders.kryo[Map[String, Int]]
}

object Question {
  def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder().master("local").getOrCreate()

      val sc = spark.sparkContext

      import spark.implicits._

      val ds = sc.parallelize((1 to 10).map(i => Data(i))).toDS()

      val noop = new NoopAgg[Data]().toColumn

      val result = ds.groupByKey(_.key).agg(noop.as("my_sum").as[Map[String, Int]])

      result.printSchema()
  }
}

It prints

root
 |-- value: integer (nullable = false)
 |-- my_sum: binary (nullable = true)

Solution

  • There is no inference here at all. Instead you get more or less what you ask for. Specifically the mistake is here:

    override def outputEncoder: Encoder[Map[String, Int]] = Encoders.kryo[Map[String, Int]]
    

    Encoders.kryo means that you apply general-purpose serialization and get back a binary blob. The misleading part is .as[Map[String, Int]]: contrary to what one might expect, it is not statically type checked. To make it worse, it is not proactively validated by the query planner either, and the exception is thrown only when the result is evaluated:

    result.first
    
    org.apache.spark.sql.AnalysisException: cannot resolve '`my_sum`' due to data type mismatch: cannot cast binary to map<string,int>;
      at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
      at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$3.applyOrElse(CheckAnalysis.scala:115)
    ...
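
    The difference between the two encoders can be seen in their schemas before any query runs. The following is a minimal sketch, assuming Spark SQL is on the classpath; the object name EncoderSchemas is hypothetical and only there for illustration:

    ```scala
    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.types.{BinaryType, MapType}

    // Hypothetical helper object, not part of the original question.
    object EncoderSchemas {
      // Kryo serializes the whole object into one opaque binary column.
      val kryoSchema = Encoders.kryo[Map[String, Int]].schema

      // ExpressionEncoder derives a Catalyst type from the Scala type by reflection.
      val exprSchema = ExpressionEncoder[Map[String, Int]]().schema

      def main(args: Array[String]): Unit = {
        println(kryoSchema.simpleString) // struct<value:binary>
        println(exprSchema.simpleString) // expected to describe a map of string to int
      }
    }
    ```

    Whatever schema the output encoder reports is what ends up as the column type of the aggregation result.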
    

    You should provide a specific Encoder instead, either explicitly:

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder  
    
    def outputEncoder: Encoder[Map[String, Int]] = ExpressionEncoder()
    

    or implicitly:

    class NoopAgg[I](implicit val enc: Encoder[Map[String, Int]]) extends Aggregator[I, Map[String, Int], Map[String, Int]] {
      ...
      override def outputEncoder: Encoder[Map[String, Int]] = enc
    }
    

    As a side effect, this makes as[Map[String, Int]] redundant, because the return type of the Aggregator is already known.
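
    Putting the explicit variant together, a complete version of the example from the question might look like the sketch below. It assumes a Spark version where ExpressionEncoder() can derive an encoder for Map[String, Int]. The object name Answer is hypothetical. It also uses name("my_sum") instead of as("my_sum"), because name keeps the column typed, so no further as[...] call is needed:

    ```scala
    import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.expressions.Aggregator

    case class Data(key: Int)

    class NoopAgg[I] extends Aggregator[I, Map[String, Int], Map[String, Int]] {
      override def zero: Map[String, Int] = Map.empty[String, Int]

      override def reduce(b: Map[String, Int], a: I): Map[String, Int] = b

      override def merge(b1: Map[String, Int], b2: Map[String, Int]): Map[String, Int] = b1

      override def finish(reduction: Map[String, Int]): Map[String, Int] = reduction

      // The buffer is internal to the aggregation, so a binary Kryo encoding is acceptable here.
      override def bufferEncoder: Encoder[Map[String, Int]] = Encoders.kryo[Map[String, Int]]

      // The output encoder determines the type of the result column.
      override def outputEncoder: Encoder[Map[String, Int]] = ExpressionEncoder[Map[String, Int]]()
    }

    // Hypothetical entry point, mirroring the Question object from the question.
    object Answer {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local").getOrCreate()
        import spark.implicits._

        val ds = (1 to 10).map(i => Data(i)).toDS()

        // name(...) returns a TypedColumn, unlike as(String), which returns an untyped Column.
        val noop = new NoopAgg[Data]().toColumn.name("my_sum")

        val result = ds.groupByKey(_.key).agg(noop)

        // The aggregated column should now be reported as a map rather than binary.
        result.printSchema()

        spark.stop()
      }
    }
    ```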