scalaapache-sparkmutableuser-defined-aggregate

Why does a mutable Map automatically become immutable in a UserDefinedAggregateFunction (UDAF) in Spark?


I am trying to define a UserDefinedAggregateFunction (UDAF) in Spark that counts the number of occurrences of each unique value in a column within each group.

This is an example: Suppose I have a dataframe df like this,

+----+----+
|col1|col2|
+----+----+
|   a|  a1|
|   a|  a1|
|   a|  a2|
|   b|  b1|
|   b|  b2|
|   b|  b3|
|   b|  b1|
|   b|  b1|
+----+----+

I will have a UDAF DistinctValues

// Instantiate the UDAF so it can be applied to a column.
val func = new DistinctValues

Then I apply it to the dataframe df

// Group by col1 and aggregate col2 with the UDAF; the result column is named "DV".
val agg_value = df.groupBy("col1").agg(func(col("col2")).as("DV"))

I am expecting to have something like this:

+----+--------------------------+
|col1|DV                        |
+----+--------------------------+
|   a|  Map(a1->2, a2->1)       |
|   b|  Map(b1->3, b2->1, b3->1)|
+----+--------------------------+

So I came up with a UDAF like this:

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.MapType
import org.apache.spark.sql.types.LongType
import Array._

/**
 * UDAF intended to count, per group, the occurrences of each distinct string
 * value, keeping the counts in a map-typed aggregation buffer.
 */
class DistinctValues extends UserDefinedAggregateFunction {
  // One string input column.
  def inputSchema: org.apache.spark.sql.types.StructType = StructType(StructField("value", StringType) :: Nil)

  // The buffer has a single field: a map from string to long.
  def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType))::Nil)

  // The result has the same map type as the buffer field.
  def dataType: DataType =  MapType(StringType, LongType)
  def deterministic: Boolean = true

  // Stores an empty mutable map in buffer slot 0.
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = scala.collection.mutable.Map()
  }

  // Increments the count for the incoming value and writes the map back to slot 0.
  // NOTE(review): the ClassCastException reported for this code originates in this
  // method (immutable.Map$EmptyMap$ cannot be cast to mutable.Map), so the value read
  // back from the buffer is evidently not the mutable map that initialize stored.
  def update(buffer: MutableAggregationBuffer, input: Row) : Unit = {
    val str = input.getAs[String](0)
    var mp = buffer.getAs[scala.collection.mutable.Map[String, Long]](0)
    var c:Long = mp.getOrElse(str, 0)
    c = c + 1
    mp.put(str, c)
    buffer(0) = mp
  }

  // Adds every count from buffer2's map into buffer1's map and writes it back to slot 0.
  // NOTE(review): uses the same mutable.Map cast as update — presumably fails the same way.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) : Unit = {
    var mp1 = buffer1.getAs[scala.collection.mutable.Map[String, Long]](0)
    var mp2 = buffer2.getAs[scala.collection.mutable.Map[String, Long]](0)
    mp2 foreach {
        case (k ,v) => {
            var c:Long = mp1.getOrElse(k, 0)
            c = c + v
            mp1.put(k ,c)
        }
    }
    buffer1(0) = mp1
  }

  // Returns the map held in buffer slot 0 as the aggregate result.
  // NOTE(review): the type argument uses LongType (a Spark DataType) whereas update and
  // merge read the values as Long — likely meant Long; confirm.
  def evaluate(buffer: Row): Any = {
      buffer.getAs[scala.collection.mutable.Map[String, LongType]](0)
  }
}

Then I have this function on my dataframe,

// Instantiate the UDAF and apply it to col2 within each col1 group; the result column is named "DV".
val func = new DistinctValues
val agg_values = df.groupBy("col1").agg(func(col("col2")).as("DV"))

It gave the following error:

func: DistinctValues = $iwC$$iwC$DistinctValues@17f48a25
org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 32.0 failed 4 times, most recent failure: Lost task 1.3 in stage 32.0 (TID 884, ip-172-31-22-166.ec2.internal): java.lang.ClassCastException: scala.collection.immutable.Map$EmptyMap$ cannot be cast to scala.collection.mutable.Map
at $iwC$$iwC$DistinctValues.update(<console>:39)
at org.apache.spark.sql.execution.aggregate.ScalaUDAF.update(udaf.scala:431)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:187)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:180)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.processCurrentSortedGroup(SortBasedAggregationIterator.scala:116)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:152)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:29)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:149)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:73)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41)
at org.apache.spark.scheduler.Task.run(Task.scala:89)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:213)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
at java.lang.Thread.run(Thread.java:745)

It looks like, in the update(buffer: MutableAggregationBuffer, input: Row) method, the variable buffer holds an immutable.Map, and the program tried to cast it to mutable.Map.

But I used a mutable.Map to initialize the buffer variable in the initialize(buffer: MutableAggregationBuffer) method. Is it the same variable that is passed to the update method? Also, buffer is a MutableAggregationBuffer, so it should be mutable, right?

Why did my mutable.Map become immutable? Does anyone know what happened?

I really need a mutable Map in this function to complete the task. I know there is a workaround: create a mutable map from the immutable map, then update it. But I really want to know why the mutable map is automatically turned into an immutable one; it doesn't make sense to me.


Solution

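  • I believe the cause is the MapType in your buffer's StructType: the stored value is converted according to the buffer schema, so what you read back from buffer is a plain Scala Map, which is immutable, not the mutable.Map you put in.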

    You can convert it, but why don't you just leave it immutable and do this:

    mp = mp + (k -> c)
    

    to add an entry to the immutable Map?

    Working example below:

    /**
     * UDAF that counts, per group, how many times each distinct string value occurs.
     *
     * Both the aggregation buffer and the result are a `Map[String, Long]`
     * (value -> occurrence count). The map read back from the buffer is an
     * immutable Scala `Map`, so each step builds a new map and stores it back
     * into the buffer instead of mutating it in place.
     */
    class DistinctValues extends UserDefinedAggregateFunction {
      /** Single string input column whose values are counted. */
      def inputSchema: org.apache.spark.sql.types.StructType = StructType(StructField("value", StringType) :: Nil)

      /** One buffer field holding the running value -> count map. */
      def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType)) :: Nil)

      /** The result has the same map type as the buffer field. */
      def dataType: DataType = MapType(StringType, LongType)

      def deterministic: Boolean = true

      /** Starts every group with an empty, explicitly typed count map. */
      def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = Map.empty[String, Long]
      }

      /**
       * Adds one occurrence of the incoming value to the buffer's count map.
       *
       * @param buffer aggregation buffer whose slot 0 holds the count map
       * @param input  row whose column 0 is the value to count
       */
      def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        // Null inputs are skipped so that null never becomes a map key.
        if (!input.isNullAt(0)) {
          val value = input.getAs[String](0)
          val counts = buffer.getAs[Map[String, Long]](0)
          buffer(0) = counts + (value -> (counts.getOrElse(value, 0L) + 1L))
        }
      }

      /**
       * Merges the counts of `buffer2` into `buffer1` by summing per key.
       *
       * @param buffer1 buffer that receives the merged counts
       * @param buffer2 partial result to fold into `buffer1`
       */
      def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        val merged = buffer1.getAs[Map[String, Long]](0)
        val incoming = buffer2.getAs[Map[String, Long]](0)
        buffer1(0) = incoming.foldLeft(merged) { case (acc, (key, count)) =>
          acc + (key -> (acc.getOrElse(key, 0L) + count))
        }
      }

      /** Returns the final value -> count map for the group. */
      def evaluate(buffer: Row): Any = buffer.getAs[Map[String, Long]](0)
    }