Tags: dataframe, scala, apache-spark, apache-spark-sql, scala-spark

Add a tag to a list in a DataFrame based on a threshold for the values in another list, in Scala Spark


I have a DataFrame with a column "grades" containing a list of Grade objects, each with two fields: name (String) and value (Double). I would like to add the word PASS to the tags list if the grades list contains a Grade with the name HOME and a value of at least 20.0. Example below:

INPUT:
+------+-----+----+-------+---------------------------------------------------------+
| model| cnd | age| tags  | grades                                                  |
+------+-----+----+-------+---------------------------------------------------------+
|  foo1|   xx|  10|  []   | [{name:"ATW", value: 10.0}, {name:"HOME", value: 20.0}] |
|  foo2|   xz|  12|  []   | [{name:"ATW", value: 70.0}]                             |
|  foo3|   xc|  13|  []   | [{name:"ATW", value: 90.0}, {name:"HOME", value: 10.0}] |
+------+-----+----+-------+---------------------------------------------------------+



OUTPUT:

+------+-----+----+-------+---------------------------------------------------------+
| model| cnd | age| tags  | grades                                                  |
+------+-----+----+-------+---------------------------------------------------------+
|  foo1|   xx|  10| [PASS]| [{name:"ATW", value: 10.0}, {name:"HOME", value: 20.0}] |
|  foo2|   xz|  12|  []   | [{name:"ATW", value: 70.0}]                             |
|  foo3|   xc|  13|  []   | [{name:"ATW", value: 90.0}, {name:"HOME", value: 10.0}] |
+------+-----+----+-------+---------------------------------------------------------+
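
For reference, a minimal sketch of how this sample data could be constructed in Scala Spark (the Grade case class is taken from the description above; the surrounding Model case class and session setup are assumptions for illustration):

    import org.apache.spark.sql.SparkSession

    case class Grade(name: String, value: Double)
    case class Model(model: String, cnd: String, age: Int,
                     tags: Seq[String], grades: Seq[Grade])

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // rows mirror the INPUT table above
    val dataFrame = Seq(
      Model("foo1", "xx", 10, Seq(), Seq(Grade("ATW", 10.0), Grade("HOME", 20.0))),
      Model("foo2", "xz", 12, Seq(), Seq(Grade("ATW", 70.0))),
      Model("foo3", "xc", 13, Seq(), Seq(Grade("ATW", 90.0), Grade("HOME", 10.0)))
    ).toDF()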

I haven't been able to find a reasonable solution. So far I have got this:

    dataFrame.withColumn("tags",
      when(
        array_contains(
          col("grades.name"),
          lit("HOME")
        ) && col("grades.value") >= lit(20.0),
        array_union(col("tags"), lit(Array("PASS")))
      ).otherwise(col("tags"))
    )

But this code, for some reason, throws:

org.apache.spark.sql.AnalysisException: cannot resolve '(`grades`.`value` >= 20.0D)' due to data type mismatch: differing types in '(`grades`.`value` >= 20.0D)' (array<double> and double).;;

The data is read from BigQuery, and the value field definitely does not contain an array of doubles.


Solution

  • Assume your dataset is called data (shown below, simplified to just the relevant columns):

    +----+---------------------------+
    |tags|grades                     |
    +----+---------------------------+
    |[]  |[{ATW, 10.0}, {HOME, 20.0}]|
    |[]  |[{ATW, 70.0}]              |
    |[]  |[{ATW, 90.0}, {HOME, 10.0}]|
    +----+---------------------------+
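
    First, a note on why the original attempt fails: referencing a struct field through an array-of-structs column collapses it into an array of that field's values, so grades.value resolves to array<double>, not double. A quick sketch to confirm (column names as above):

    import org.apache.spark.sql.functions.col

    // grades is array<struct<name:string,value:double>>, so selecting
    // grades.value yields an array of doubles, hence the type mismatch
    data.select(col("grades.value")).printSchema()
    // prints something like:
    // root
    //  |-- value: array (nullable = true)
    //  |    |-- element: double (containsNull = true)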
    

    If your grades column happens to be a string, you may first want to parse the JSON into an array of structs as below (skip this step otherwise):

    data = data.withColumn("grades",
      // parse the JSON string into array<struct<name:string,value:double>>
      expr("from_json(grades, 'array<struct<name:string,value:double>>')")
    )
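
    Equivalently, the JSON parse can be written with the typed API instead of a SQL string (a sketch, assuming the same schema):

    import org.apache.spark.sql.functions.{col, from_json}
    import org.apache.spark.sql.types._

    val gradesSchema = ArrayType(StructType(Seq(
      StructField("name", StringType),
      StructField("value", DoubleType)
    )))

    data = data.withColumn("grades", from_json(col("grades"), gradesSchema))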
    

    Once this is in place, we can apply the following, using the filter higher-order function (Spark 2.4+) to inspect each array element individually:

    data = data.withColumn("tags",
      when(
        // met when there is at least one grade with name = HOME and value >= 20
        expr("size(filter(grades, x -> x.name == 'HOME' and x.value >= 20))").geq(1),
        // append "PASS" to whatever is already in the tags column
        array_union(col("tags"), array(lit("PASS")))
        // otherwise, leave the tags column untouched
      ).otherwise(col("tags")))
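
    On Spark 3.0+, the same condition can also be written without a SQL string, via the exists higher-order function from org.apache.spark.sql.functions (a sketch of an equivalent formulation):

    import org.apache.spark.sql.functions._

    data = data.withColumn("tags",
      when(
        // true when at least one grade has name = HOME and value >= 20
        exists(col("grades"), g => g.getField("name") === "HOME" && g.getField("value") >= 20.0),
        array_union(col("tags"), array(lit("PASS")))
      ).otherwise(col("tags")))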
    

    Final output looks like:

    +------+---------------------------+
    |tags  |grades                     |
    +------+---------------------------+
    |[PASS]|[{ATW, 10.0}, {HOME, 20.0}]|
    |[]    |[{ATW, 70.0}]              |
    |[]    |[{ATW, 90.0}, {HOME, 10.0}]|
    +------+---------------------------+
    

    Good luck!