Spark UDF with varargs


Is the only option to list all the arguments, up to 22, as shown in the documentation?

https://spark.apache.org/docs/1.5.0/api/scala/index.html#org.apache.spark.sql.UDFRegistration

Has anyone figured out how to do something like this?

sc.udf.register("func", (s: String*) => s......

(I'm writing a custom concat function that skips nulls, and so far I've had to handle the arguments 2 at a time.)

Thanks


Solution

  • UDFs don't support varargs*, but you can pass an arbitrary number of columns wrapped in an array using the array function:

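    import org.apache.spark.sql.functions.{udf, array, lit}
    
    // Drop nulls, then join the remaining values with sep.
    // An array column is passed to a Scala UDF as a Seq.
    val myConcatFunc = (xs: Seq[Any], sep: String) => 
      xs.filter(_ != null).mkString(sep)
    
    val myConcat = udf(myConcatFunc)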
    

    An example usage:

    import sqlContext.implicits._  // for toDF and the $"..." column syntax
    
    val df = sc.parallelize(Seq(
      (null, "a", "b", "c"), ("d", null, null, "e")
    )).toDF("x1", "x2", "x3", "x4")
    
    val cols = array($"x1", $"x2", $"x3", $"x4")
    val sep = lit("-")
    
    df.select(myConcat(cols, sep).alias("concatenated")).show
    
    // +------------+
    // |concatenated|
    // +------------+
    // |       a-b-c|
    // |         d-e|
    // +------------+
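    Because myConcatFunc is an ordinary Scala function, its null-skipping logic can be checked without a SparkContext. The sketch below calls it directly on plain Seq values that stand in for the array column on each row (the row values mirror the DataFrame above; row1 and row2 are illustrative names only). Note that a row whose values are all null produces an empty string, not null.

```scala
// Same function as in the answer: drop nulls, then join what is left with sep.
val myConcatFunc = (xs: Seq[Any], sep: String) =>
  xs.filter(_ != null).mkString(sep)

// Plain Seqs standing in for array($"x1", $"x2", $"x3", $"x4") on each row.
val row1: Seq[Any] = Seq(null, "a", "b", "c")
val row2: Seq[Any] = Seq("d", null, null, "e")

println(myConcatFunc(row1, "-")) // a-b-c
println(myConcatFunc(row2, "-")) // d-e
```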
    

    With raw SQL:

    df.registerTempTable("df")
    sqlContext.udf.register("myConcat", myConcatFunc)
    
    sqlContext.sql(
        "SELECT myConcat(array(x1, x2, x4), '.') AS concatenated FROM df"
    ).show
    
    // +------------+
    // |concatenated|
    // +------------+
    // |         a.c|
    // |         d.e|
    // +------------+
    

    A slightly more complicated approach is to not use a UDF at all and instead compose SQL expressions, roughly like this:

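    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.Column
    import java.util.regex.Pattern
    
    // Append each non-null column followed by sep, then strip the trailing sep.
    // Pattern.quote makes sep match literally even if it contains regex metacharacters.
    def myConcatExpr(sep: String, cols: Column*) = regexp_replace(
      cols.foldLeft(lit(""))(
        (acc, c) => when(c.isNotNull, concat(acc, c, lit(sep))).otherwise(acc)
      ),
      s"(${Pattern.quote(sep)})?$$", ""
    )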
    
    df.select(
      myConcatExpr("-", $"x1", $"x2", $"x3", $"x4").alias("concatenated")
    ).show
    // +------------+
    // |concatenated|
    // +------------+
    // |       a-b-c|
    // |         d-e|
    // +------------+
    

    but I doubt it is worth the effort unless you work with PySpark, where avoiding a Python UDF also avoids shipping data between the JVM and Python workers.
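    The expression above folds over the columns, appending each non-null value followed by the separator, and then strips a single trailing separator with a regular expression. The plain-Scala model below is only an illustration: foldSkipNulls is a hypothetical name, and it works on Strings (with null standing in for a null column value) rather than on Columns. It also shows why the separator is worth passing through Pattern.quote: unquoted, a separator such as "+" gives the invalid pattern "(+)?$", and "|" gives "(|)?$", an empty alternation that never removes the trailing "|".

```scala
import java.util.regex.Pattern

// Hypothetical plain-Scala model of myConcatExpr: Strings stand in for Columns.
def foldSkipNulls(sep: String, values: String*): String = {
  // Append each non-null value followed by sep, like the when(...).otherwise(acc) fold.
  val withTrailing = values.foldLeft("") { (acc, v) =>
    if (v != null) acc + v + sep else acc
  }
  // Strip one trailing sep; Pattern.quote makes sep match literally.
  withTrailing.replaceAll(s"(${Pattern.quote(sep)})?$$", "")
}

println(foldSkipNulls("-", null, "a", "b", "c")) // a-b-c
println(foldSkipNulls("|", "d", null, null, "e")) // d|e
```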


    * If you pass a function that uses varargs, it is stripped of all the syntactic sugar and the resulting UDF expects an ArrayType. For example:

    def f(s: String*) = s.mkString
    udf(f _)
    

    will be of type:

    UserDefinedFunction(<function1>,StringType,List(ArrayType(StringType,true)))
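    The same desugaring can be seen in plain Scala without Spark: eta-expanding a varargs method with f _ yields an ordinary one-argument function whose parameter is a Seq, which is why Spark sees a single ArrayType input. A minimal sketch:

```scala
// A varargs method: inside the body, s is a Seq[String].
def f(s: String*): String = s.mkString

// Eta-expansion drops the varargs sugar; the result is a Function1 over Seq.
val g: Seq[String] => String = f _

println(f("a", "b", "c"))      // abc (varargs call on the method)
println(g(Seq("a", "b", "c"))) // abc (the function takes one Seq argument)
// g("a", "b", "c") would not compile: g has exactly one parameter.
```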