scala, apache-spark, dataframe, apache-spark-sql

Filtering rows based on column values in Spark dataframe Scala


I have a dataframe (Spark):

id  value 
3     0
3     1
3     0
4     1
4     0
4     0

I want to create a new dataframe:

3 0
3 1
4 1

I need to remove all the rows after the first 1 (in value) for each id. I tried window functions in a Spark DataFrame (Scala) but couldn't find a solution. It seems I am going in the wrong direction.

I am looking for a solution in Scala.

Output using monotonically_increasing_id:

scala> val data = Seq((3,0),(3,1),(3,0),(4,1),(4,0),(4,0)).toDF("id", "value")
data: org.apache.spark.sql.DataFrame = [id: int, value: int]

scala> val dataWithIndex = data.withColumn("idx", monotonically_increasing_id())
dataWithIndex: org.apache.spark.sql.DataFrame = [id: int, value: int ... 1 more field]

scala> val minIdx = dataWithIndex.filter($"value" === 1).groupBy($"id").agg(min($"idx")).toDF("r_id", "min_idx")
minIdx: org.apache.spark.sql.DataFrame = [r_id: int, min_idx: bigint]

scala> dataWithIndex.join(minIdx,($"r_id" === $"id") && ($"idx" <= $"min_idx")).select($"id", $"value").show
+---+-----+
| id|value|
+---+-----+
|  3|    0|
|  3|    1|
|  4|    1|
+---+-----+

This solution won't work if the original dataframe has been sorted, because monotonically_increasing_id() is generated based on the original DF rather than the sorted DF. I had missed that requirement earlier.

All suggestions are welcome.


Solution

  • One way is to use monotonically_increasing_id() and a self-join:

    val data = Seq((3,0),(3,1),(3,0),(4,1),(4,0),(4,0)).toDF("id", "value")
    data.show
    +---+-----+
    | id|value|
    +---+-----+
    |  3|    0|
    |  3|    1|
    |  3|    0|
    |  4|    1|
    |  4|    0|
    |  4|    0|
    +---+-----+
    

    Now we generate a column named idx with an increasing Long:

    val dataWithIndex = data.withColumn("idx", monotonically_increasing_id())
    // dataWithIndex.cache()
    

    Now we get the min(idx) for each id where value = 1:

    val minIdx = dataWithIndex
                   .filter($"value" === 1)
                   .groupBy($"id")
                   .agg(min($"idx"))
                   .toDF("r_id", "min_idx")
    

    Now we join the min(idx) back to the original DataFrame:

    dataWithIndex.join(
      minIdx,
      ($"r_id" === $"id") && ($"idx" <= $"min_idx")
    ).select($"id", $"value").show
    +---+-----+
    | id|value|
    +---+-----+
    |  3|    0|
    |  3|    1|
    |  4|    1|
    +---+-----+
    

    Note: monotonically_increasing_id() generates its value from the partition ID and the row's position within that partition, so the value may change each time dataWithIndex is re-evaluated. In the code above, because of lazy evaluation, monotonically_increasing_id() is only computed when the final show is called.

    If you want to force the values to stay the same, for example so you can use show to evaluate the steps above one at a time, uncomment this line from above:

    //  dataWithIndex.cache()
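
  • Since the question mentions window functions, here is a sketch of an alternative that also addresses the sorting caveat. This is an assumption-laden sketch, not part of the accepted answer: it uses rdd.zipWithIndex (which numbers rows in the DataFrame's current order, so the index survives a prior sort, unlike monotonically_increasing_id()) to build an ordering column, then a cumulative sum over a window to drop every row after the first value = 1. The names sortedData, dataWithIdx, and ones_before are illustrative.

    ```scala
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    val spark = SparkSession.builder
      .appName("filter-after-first-1")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val sortedData = Seq((3,0),(3,1),(3,0),(4,1),(4,0),(4,0)).toDF("id", "value")

    // zipWithIndex numbers rows in the DataFrame's current order, so the
    // index is stable even after a sort transformation.
    val dataWithIdx = spark.createDataFrame(
      sortedData.rdd.zipWithIndex.map { case (row, idx) => Row.fromSeq(row.toSeq :+ idx) },
      sortedData.schema.add("idx", "long")
    )

    // Count the 1s that appear strictly before each row within its id group;
    // a row is kept only while no 1 has been seen yet (so the first 1 itself
    // is kept, and everything after it is dropped).
    val w = Window.partitionBy($"id").orderBy($"idx")
                  .rowsBetween(Window.unboundedPreceding, -1)

    dataWithIdx
      .withColumn("ones_before", coalesce(sum(when($"value" === 1, 1)).over(w), lit(0)))
      .filter($"ones_before" === 0)
      .select($"id", $"value")
      .show
    ```

    On the sample data this should keep only (3,0), (3,1), and (4,1), matching the join-based output, though show's row order may vary across partitions. One behavioral difference from the self-join: ids that contain no 1 at all keep all of their rows here, whereas the inner join drops them entirely.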