I have a Delta Lake table with a `time` column and a `count` (int) column. The rows need to be coalesced so that the resulting DataFrame groups them into 2-day intervals. Each 2-day window is measured from the timestamp of the first row of the group, and the following rows are coalesced into that first row. Once a row's timestamp is more than 172800 seconds (2 days) after the group's start, that row becomes the start-time reference for the next group, and so on. The counts of all coalesced rows are summed into the grouped row.
For example, the original DataFrame has the following rows:
**time, count**
2019-02-18 11:03:55, 500
2019-02-18 11:06:18, 30
2019-02-18 11:07:58, 20
2019-02-18 11:07:58, 12
2019-02-18 11:08:38, 8
2019-02-18 11:10:29, 2
2019-02-20 11:09:12, 25
2019-02-20 11:10:10, 10
2019-04-02 10:10:10, 1
2019-04-05 10:10:10, 2
2019-04-09 10:10:09, 4
2019-04-11 10:10:30, 6
2019-04-13 10:10:10, 3
2019-04-16 10:10:10, 5
2019-04-19 10:10:10, 7
2019-04-21 10:10:10, 8
The expected result is:
**time => count_sum**
2019-02-18 11:03:55 => (500 + 30 + 20 + 12 + 8 + 2) = 572
2019-02-20 11:09:12 => (25 + 10) = 35
2019-04-02 10:10:10 => 1
2019-04-05 10:10:10 => 2
2019-04-09 10:10:09 => 4
2019-04-11 10:10:30 => (6+3) = 9
2019-04-16 10:10:10 => 5
2019-04-19 10:10:10 => (7+8) = 15
Any ideas on how to solve this?
The code below works for my dataset:
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, min as min_
from pyspark.sql.types import StructType, StructField, TimestampType, LongType, IntegerType
from pyspark.sql.window import Window
spark = SparkSession.builder.appName("DeltaLakeCoalesce").getOrCreate()

# Sample data: every row belongs to hash 1. The comprehension prepends the
# hash so each tuple has the shape (hash, time, count).
data = [
    (1, ts, cnt)
    for ts, cnt in (
        ("2019-02-18 11:03:55", 500),
        ("2019-02-18 11:06:18", 30),
        ("2019-02-18 11:07:58", 20),
        ("2019-02-18 11:07:58", 12),
        ("2019-02-18 11:08:38", 8),
        ("2019-02-18 11:10:29", 2),
        ("2019-02-20 11:09:12", 25),
        ("2019-02-20 11:10:10", 10),
        ("2019-02-22 11:10:00", 10),
        ("2019-02-22 12:10:00", 4),
        ("2019-02-23 11:10:00", 5),
        ("2019-02-24 11:10:10", 6),
        ("2019-04-02 10:10:10", 1),
        ("2019-04-05 10:10:10", 2),
        ("2019-04-09 10:10:09", 4),
        ("2019-04-11 10:10:30", 6),
        ("2019-04-12 10:10:30", 4),
        ("2019-04-13 10:10:10", 3),
        ("2019-04-16 10:10:10", 5),
        ("2019-04-19 10:10:10", 7),
        ("2019-04-21 10:10:10", 8),
    )
]

# Build the DataFrame and parse the string timestamps into a timestamp column.
df = spark.createDataFrame(data, ["hash", "time", "count"]).withColumn(
    "time", F.to_timestamp("time")
)
def assign_group_id(df, offset: int = 0, threshold_seconds: int = 172800):
    """Run one grouping pass over ``df`` and flag rows that need another pass.

    Within each ``hash`` (ordered by time), a gap-based group starts at the
    first row and at every row more than ``threshold_seconds`` after its
    predecessor. The group's reference time (``start_time``) is its earliest
    ``time_unix``. Rows more than ``threshold_seconds`` after that reference
    get ``group_id1 == 1`` so the caller can regroup them in another pass;
    all other rows get ``group_id1 == 0``.

    Args:
        df: DataFrame with at least ``hash`` and ``time`` (timestamp) columns.
            Helper columns left by a previous pass are dropped or overwritten.
        offset: Added to every ``group_id``; the caller uses it to keep the
            ids of different passes apart.
        threshold_seconds: Window length in seconds, measured from the group's
            first row. Defaults to 172800 (2 days).

    Returns:
        ``df`` with the columns ``time_unix``, ``time_diff``, ``group_id0``,
        ``group_id``, ``start_time``, ``time_diff2`` and ``group_id1``.
    """
    df = df.drop("group_id1", "time_diff2", "start_time", "time_diff", "group_id0")
    df = df.withColumn("time_unix", F.unix_timestamp(F.col("time")))

    # Seconds since the previous row of the same hash (null for the first row).
    by_time = Window.partitionBy("hash").orderBy("time_unix")
    df = df.withColumn("time_diff", col("time_unix") - F.lag("time_unix").over(by_time))

    # 1 marks the first row of a gap-based group: the first row of a hash, or
    # a row more than threshold_seconds after its predecessor.
    df = df.withColumn(
        "group_id0",
        F.when(
            col("time_diff").isNull() | (col("time_diff") > threshold_seconds), 1
        ).otherwise(0),
    )
    # Running count of group starts -> group id. It restarts for every hash,
    # so it is only unique together with the hash.
    df = df.withColumn("group_id", offset + F.sum(col("group_id0")).over(by_time))

    # Reference time = earliest row of the group. Partitioning by group_id
    # alone would mix rows of different hashes that share the same id.
    by_group = Window.partitionBy("hash", "group_id")
    df = df.withColumn("start_time", min_(col("time_unix")).over(by_group))

    df = df.withColumn("time_diff2", col("time_unix") - col("start_time"))
    # 1 = outside the window anchored at start_time -> regroup in the next pass.
    df = df.withColumn(
        "group_id1", F.when(col("time_diff2") > threshold_seconds, 1).otherwise(0)
    )
    return df
# Iteratively coalesce rows. Each pass finalizes the rows that fall inside the
# 2-day window anchored at their group's first row and re-processes the rest.
# Every pass finalizes at least the first row of each remaining group, so the
# loop always terminates. It runs until nothing is left; a fixed pass count
# would silently drop rows that still needed regrouping.
offset = 0
df_all = df
df_correct = None  # union of the rows finalized by all passes so far
while True:
    # Assign group IDs. localCheckpoint() materializes the pass and truncates
    # the lineage, so later passes do not recompute all earlier ones.
    df_all = assign_group_id(df_all, offset).localCheckpoint()
    # Rows within the window of their group's start are final.
    df_correct_new = df_all.filter(col("group_id1") == 0)
    df_correct = (
        df_correct_new
        if df_correct is None
        else df_correct.unionByName(df_correct_new)
    )
    # Rows beyond the window need another pass.
    df_all = df_all.filter(col("group_id1") == 1)
    if not df_all.head(1):
        break
    # Shift group ids for the next pass.
    offset += 100000

df_correct.orderBy("hash", "time_unix").show(truncate=False, n=100)

# Aggregate per group. (hash, start_time) identifies a group uniquely:
# group_id restarts per hash, and start_time is the group's first timestamp.
# min("time") is that first row's timestamp; F.first() would be
# non-deterministic after the groupBy shuffle.
result_df = df_correct.groupBy("hash", "start_time").agg(
    F.min("time").alias("time"),
    F.sum("count").alias("count_sum"),
).orderBy("hash", "time")
result_df.select("hash", "time", "count_sum").show(truncate=False, n=100)