dataframe scala apache-spark apache-spark-sql

How to take the first row in a Spark Scala dataframe if the status duration is greater than a threshold


I have a dataframe as shown below. I'm loading raw data from an HBase table once every hour. If the status is equal to 1 continuously for more than 10 minutes, then I need to take the first row of that run for every hourly batch. Similarly, I need to do this for the other IDs as well.

ID Timestamp status
1 2021-01-01 10:00:00 1
1 2021-01-01 10:01:06 1
1 2021-01-01 10:02:18 1
1 2021-01-01 10:03:24 1
1 2021-01-01 10:04:30 1
1 2021-01-01 10:05:36 1
1 2021-01-01 11:06:00 1
1 2021-01-01 11:07:06 1
1 2021-01-01 11:08:12 1
1 2021-01-01 11:09:24 1
1 2021-01-01 11:10:30 1
1 2021-01-01 11:11:36 1

I am expecting output like below:

ID Timestamp status
1 2021-01-01 10:00:00 1

Can you please help me with this?

NOTE: If the status value changes within the threshold, then there is no need to consider that run. For example, as shown below:

ID Timestamp status
1 2021-01-01 10:00:00 1
1 2021-01-01 10:01:06 1
1 2021-01-01 10:02:18 1
1 2021-01-01 10:03:24 1
1 2021-01-01 10:04:30 1
1 2021-01-01 10:05:36 2
1 2021-01-01 11:06:00 3
1 2021-01-01 11:07:06 1
1 2021-01-01 11:08:12 1
1 2021-01-01 11:09:24 1
1 2021-01-01 11:10:30 1
1 2021-01-01 11:11:36 1

Solution

  • I think this looks like the gaps-and-islands SQL problem (and I'm pretty sure there are cleaner, more idiomatic solutions):

    // build the sample dataframe (the second example from the question,
    // where the status changes within the threshold)
    val data = Seq(
      (1, "2021-01-01T10:00:00", 1), (1, "2021-01-01T10:01:06", 1), (1, "2021-01-01T10:02:18", 1),
      (1, "2021-01-01T10:03:24", 1), (1, "2021-01-01T10:04:30", 1), (1, "2021-01-01T10:05:36", 2),
      (1, "2021-01-01T11:06:00", 3), (1, "2021-01-01T11:07:06", 1), (1, "2021-01-01T11:08:12", 1),
      (1, "2021-01-01T11:09:24", 1), (1, "2021-01-01T11:10:30", 1), (1, "2021-01-01T11:11:36", 1)
    )

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types._

    val schema = StructType(List(
      StructField("id", IntegerType, true),
      StructField("timestamp", TimestampType, true),
      StructField("status", IntegerType, true)
    ))

    // note: toDF only applies the schema's field names; the timestamp column
    // stays a string here and gets cast to TimestampType further down
    val df = spark.createDataFrame(data).toDF(schema.fieldNames: _*)
    
    df.show()
    +---+-------------------+------+
    | id|          timestamp|status|
    +---+-------------------+------+
    |  1|2021-01-01T10:00:00|     1|
    |  1|2021-01-01T10:01:06|     1|
    |  1|2021-01-01T10:02:18|     1|
    |  1|2021-01-01T10:03:24|     1|
    |  1|2021-01-01T10:04:30|     1|
    |  1|2021-01-01T10:05:36|     2|
    |  1|2021-01-01T11:06:00|     3|
    |  1|2021-01-01T11:07:06|     1|
    |  1|2021-01-01T11:08:12|     1|
    |  1|2021-01-01T11:09:24|     1|
    |  1|2021-01-01T11:10:30|     1|
    |  1|2021-01-01T11:11:36|     1|
    +---+-------------------+------+
    
    
    df.printSchema()
    root
     |-- id: integer (nullable = false)
     |-- timestamp: string (nullable = true)
     |-- status: integer (nullable = false)
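
    Since the timestamp column is still a string at this point (the schema
    above only supplied column names), one could optionally cast it to a real
    timestamp once up front instead of casting inside the window expressions
    below. A minimal sketch, where dfTyped is a name introduced just for this
    note:

    // optional: parse the ISO-8601 strings into a proper timestamp column once
    val dfTyped = df.withColumn("timestamp", col("timestamp").cast(TimestampType))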
     
    // identify contiguous runs ("islands") of the same status, then compute
    // each run's duration: the difference between a row number over all of an
    // id's rows and a row number within each (id, status) pair is constant
    // inside a run, so it labels each island
    df.withColumn("id_sequence", row_number().over(Window.partitionBy(col("id")).orderBy(col("timestamp").asc))).
      withColumn("id_and_status", row_number().over(Window.partitionBy(col("id"), col("status")).orderBy(col("timestamp").asc))).
      withColumn("group", col("id_sequence") - col("id_and_status")).
      withColumn("group_max", max(col("timestamp")).over(Window.partitionBy(col("id"), col("group"))).cast(TimestampType)).
      withColumn("group_min", min(col("timestamp")).over(Window.partitionBy(col("id"), col("group"))).cast(TimestampType)).
      withColumn("difference_in_mins", (unix_timestamp(col("group_max")) - unix_timestamp(col("group_min"))) / 60).
      orderBy(col("id"), col("timestamp").asc).
      show()
     
    +---+-------------------+------+-----------+-------------+-----+-------------------+-------------------+------------------+
    | id|          timestamp|status|id_sequence|id_and_status|group|          group_max|          group_min|difference_in_mins|
    +---+-------------------+------+-----------+-------------+-----+-------------------+-------------------+------------------+
    |  1|2021-01-01T10:00:00|     1|          1|            1|    0|2021-01-01 10:04:30|2021-01-01 10:00:00|               4.5|
    |  1|2021-01-01T10:01:06|     1|          2|            2|    0|2021-01-01 10:04:30|2021-01-01 10:00:00|               4.5|
    |  1|2021-01-01T10:02:18|     1|          3|            3|    0|2021-01-01 10:04:30|2021-01-01 10:00:00|               4.5|
    |  1|2021-01-01T10:03:24|     1|          4|            4|    0|2021-01-01 10:04:30|2021-01-01 10:00:00|               4.5|
    |  1|2021-01-01T10:04:30|     1|          5|            5|    0|2021-01-01 10:04:30|2021-01-01 10:00:00|               4.5|
    |  1|2021-01-01T10:05:36|     2|          6|            1|    5|2021-01-01 10:05:36|2021-01-01 10:05:36|               0.0|
    |  1|2021-01-01T11:06:00|     3|          7|            1|    6|2021-01-01 11:06:00|2021-01-01 11:06:00|               0.0|
    |  1|2021-01-01T11:07:06|     1|          8|            6|    2|2021-01-01 11:11:36|2021-01-01 11:07:06|               4.5|
    |  1|2021-01-01T11:08:12|     1|          9|            7|    2|2021-01-01 11:11:36|2021-01-01 11:07:06|               4.5|
    |  1|2021-01-01T11:09:24|     1|         10|            8|    2|2021-01-01 11:11:36|2021-01-01 11:07:06|               4.5|
    |  1|2021-01-01T11:10:30|     1|         11|            9|    2|2021-01-01 11:11:36|2021-01-01 11:07:06|               4.5|
    |  1|2021-01-01T11:11:36|     1|         12|           10|    2|2021-01-01 11:11:36|2021-01-01 11:07:06|               4.5|
    +---+-------------------+------+-----------+-------------+-----+-------------------+-------------------+------------------+
    

    I didn't go any further, but I think it is relatively simple to filter
    "difference_in_mins" for values greater than 10 minutes and then grab the
    first record of each run, as sketched below.
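
    For completeness, a minimal sketch of that last step, assuming the chained
    result above (before the orderBy/show) was assigned to a val, here called
    withDurations (a name introduced for this sketch), and using the 10-minute
    threshold and status == 1 condition from the question:

    // keep only runs of status 1 that lasted more than 10 minutes, then take
    // the earliest row of each surviving run
    val firstRows = withDurations.
      filter(col("status") === 1 && col("difference_in_mins") > 10).
      withColumn("rn", row_number().over(Window.partitionBy(col("id"), col("group")).orderBy(col("timestamp").asc))).
      filter(col("rn") === 1).
      select("id", "timestamp", "status")

    firstRows.show()

    On the question's first sample (status 1 throughout), this keeps the
    2021-01-01 10:00:00 row; on the sample built above nothing passes the
    filter, since the longest run lasts only 4.5 minutes.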

    Hope this helps.