dataframe, apache-spark, pyspark, delta-lake

Tricky pyspark transformation for merging rows based on timestamp durations


I have a Delta Lake table with a time column and a count (int) column. The rows need to be coalesced so that the resulting dataframe is grouped into 2-day intervals, where each interval is always measured from the timestamp of the first row in the group. Once a row's timestamp is more than 172800 seconds (2 days) past that reference, that row becomes the new start-time reference for the next group, and so on. The counts of the coalesced rows should be summed into the grouped row (a plain-Python sketch of this rule follows the expected result below).

For example, the original dataframe has the timestamps below:

**time, count**
2019-02-18 11:03:55, 500
2019-02-18 11:06:18, 30
2019-02-18 11:07:58, 20
2019-02-18 11:07:58, 12
2019-02-18 11:08:38, 8
2019-02-18 11:10:29, 2
2019-02-20 11:09:12, 25
2019-02-20 11:10:10, 10
2019-04-02 10:10:10, 1
2019-04-05 10:10:10, 2
2019-04-09 10:10:09, 4
2019-04-11 10:10:30, 6
2019-04-13 10:10:10, 3
2019-04-16 10:10:10, 5
2019-04-19 10:10:10, 7
2019-04-21 10:10:10, 8

The expected result is:

**time => count_sum**
2019-02-18 11:03:55 => (500 + 30 + 20 + 12 + 8 + 2) = 572
2019-02-20 11:09:12 => (25 + 10) = 35
2019-04-02 10:10:10 => 1
2019-04-05 10:10:10 => 2
2019-04-09 10:10:09 => 4
2019-04-11 10:10:30 => (6+3) = 9
2019-04-16 10:10:10 => 5
2019-04-19 10:10:10 => (7+8) = 15
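
To make the rule concrete, here is the grouping logic in plain Python (a minimal sketch that assumes the rows are already sorted by time); it reproduces the count sums listed above:

    from datetime import datetime, timedelta

    # (time, count) rows from the table above, already sorted by time
    rows = [
        ("2019-02-18 11:03:55", 500), ("2019-02-18 11:06:18", 30),
        ("2019-02-18 11:07:58", 20), ("2019-02-18 11:07:58", 12),
        ("2019-02-18 11:08:38", 8), ("2019-02-18 11:10:29", 2),
        ("2019-02-20 11:09:12", 25), ("2019-02-20 11:10:10", 10),
        ("2019-04-02 10:10:10", 1), ("2019-04-05 10:10:10", 2),
        ("2019-04-09 10:10:09", 4), ("2019-04-11 10:10:30", 6),
        ("2019-04-13 10:10:10", 3), ("2019-04-16 10:10:10", 5),
        ("2019-04-19 10:10:10", 7), ("2019-04-21 10:10:10", 8),
    ]

    groups = []   # each entry is [group start time, running count sum]
    anchor = None
    for ts_str, count in rows:
        ts = datetime.strptime(ts_str, "%Y-%m-%d %H:%M:%S")
        if anchor is None or ts - anchor > timedelta(seconds=172800):
            anchor = ts                    # this row starts a new 2-day group
            groups.append([ts_str, 0])
        groups[-1][1] += count             # coalesce the count into the current group

    for start, total in groups:
        print(start, "=>", total)          # e.g. 2019-02-18 11:03:55 => 572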

Any ideas to solve this?


Solution

  • The code below works for my dataset. It first groups rows by gaps of more than 2 days between consecutive timestamps, then iteratively reassigns rows that end up more than 2 days after their group's first timestamp:

    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, min as min_
    from pyspark.sql.types import StructType, StructField, TimestampType, LongType, IntegerType
    from pyspark.sql.window import Window
    
    spark = SparkSession.builder.appName("DeltaLakeCoalesce").getOrCreate()
    
    # Sample data
    data = [
        (1, "2019-02-18 11:03:55", 500),
        (1, "2019-02-18 11:06:18", 30),
        (1, "2019-02-18 11:07:58", 20),
        (1, "2019-02-18 11:07:58", 12),
        (1, "2019-02-18 11:08:38", 8),
        (1, "2019-02-18 11:10:29", 2),
        (1, "2019-02-20 11:09:12", 25),
        (1, "2019-02-20 11:10:10", 10),
        (1, "2019-02-22 11:10:00", 10),
        (1, "2019-02-22 12:10:00", 4),
        (1, "2019-02-23 11:10:00", 5),
        (1, "2019-02-24 11:10:10", 6),
        (1, "2019-04-02 10:10:10", 1),
        (1, "2019-04-05 10:10:10", 2),
        (1, "2019-04-09 10:10:09", 4),
        (1, "2019-04-11 10:10:30", 6),
        (1, "2019-04-12 10:10:30", 4),
        (1, "2019-04-13 10:10:10", 3),
        (1, "2019-04-16 10:10:10", 5),
        (1, "2019-04-19 10:10:10", 7),
        (1, "2019-04-21 10:10:10", 8)
    ]
    
    df = spark.createDataFrame(data, ["hash", "time", "count"])
    df = df.withColumn("time", F.to_timestamp("time"))
    
    def assign_group_id(df, offset: int = 0):
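        """Flag group-start rows per hash, build provisional group ids (shifted by
        `offset`), and mark rows (group_id1 == 1) that sit more than 2 days past
        their group's first timestamp so the caller can regroup them."""
        # Drop leftovers from a previous pass (drop() ignores missing columns).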
        df = df.drop("group_id1", "time_diff2", "start_time", "time_diff", "group_id0")
        df = df.withColumn("time_unix", F.unix_timestamp(F.col("time")))
    
        # Time difference in seconds from the previous row within the same hash
        window_spec = Window.partitionBy("hash").orderBy("time_unix")
        df = df.withColumn("time_diff", col("time_unix") - (F.lag("time_unix").over(window_spec)))
    
        # Flag rows that start a new group: the first row per hash (null diff) or a gap of more than 172800 seconds (2 days)
        df = df.withColumn("group_id0", F.when(F.col("time_diff").isNull() | (F.col("time_diff") > 172800), 1).otherwise(0))
    
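        # Running sum of the start flags yields a provisional group id; the offset
        # keeps ids from different passes from colliding.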
        df = df.withColumn("group_id", offset + F.sum(col("group_id0")).over(window_spec))
    
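        # The group's first (anchor) row keeps its own timestamp as start_time; every
        # other row takes the earliest timestamp of the preceding rows in its group.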
        df = df.withColumn("start_time", F.when(F.col("group_id0") == 1, col("time_unix")).otherwise(
            min_(col("time_unix")).over(Window.partitionBy("group_id").orderBy("time_unix").rowsBetween(Window.unboundedPreceding, -1))))
    
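        # Rows more than 2 days past their group's anchor are flagged for regrouping.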
        df = df.withColumn("time_diff2", col("time_unix") - col("start_time"))
        df = df.withColumn("group_id1", F.when((F.col("time_diff2") > 172800), 1).otherwise(0))
    
        return df
    
    # Initialize loop state: offset keeps group ids unique across passes,
    # and df_correct accumulates rows whose grouping has been finalised
    offset = 0
    df_all = df
    schema = StructType([
        StructField("hash", IntegerType(), False),
        StructField("time", TimestampType(), True),
        StructField("count", LongType(), True),
        StructField("time_unix", LongType(), True),
        StructField("time_diff", LongType(), True),
        StructField("group_id0", IntegerType(), False),
        StructField("group_id", LongType(), True),
        StructField("start_time", LongType(), True),
        StructField("time_diff2", LongType(), True),
        StructField("group_id1", IntegerType(), False)
    ])
    df_correct = spark.createDataFrame([], schema)
    
    # Run 5 passes: each pass finalises correctly grouped rows and sends the
    # rest (group_id1 == 1) back through assign_group_id with a fresh id offset
    for i in range(5):
        # Assign group IDs
        df_all = assign_group_id(df_all, offset)
        # Filter correctly grouped rows
        df_correct_new = df_all.filter(col("group_id1") == 0)
        # Union with previously correctly grouped rows
        df_correct = df_correct.unionByName(df_correct_new)
        df_correct.show(truncate=False, n=100)
        # Update DataFrame to only include rows that need to be regrouped
        df_all = df_all.filter(col("group_id1") == 1)
            
        # Update offset for the next iteration
        offset += 100000
    
    df_correct.orderBy("time_unix").show(truncate=False, n=100)
    
    # Group by the final group_id and aggregate the counts and the group's first (anchor) timestamp
    result_df = df_correct.groupBy("group_id").agg(
        F.min("time").alias("time"),   # earliest timestamp in the group; min is deterministic, unlike first()
        F.sum("count").alias("count_sum")
    ).orderBy("time")
    
    result_df.select("time", "count_sum").show(truncate=False, n=100)
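
  • Since the data lives in a Delta Lake table, the rollup can also be written back with the Delta writer if needed; a minimal sketch (the output path below is hypothetical):

    # Persist the 2-day rollup to a Delta table (hypothetical output path).
    result_df.select("time", "count_sum") \
        .write.format("delta") \
        .mode("overwrite") \
        .save("/delta/two_day_rollup")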