I am working in PySpark and need to implement the routine below without a recursive approach. The problem with recursion is that my data is too large (millions of rows), so a recursive routine is very expensive. The rule is as follows.
Find overlapping process dates between subsequent rows sequentially. I have numbered the rows for sequencing already.
Steps:
When the rows are overlapping, mark the is_overlap flag as 1 for both rows. I have coded this successfully up to this step, but the next step of the rule is quite complex.
If the two rows overlap, take the smallest start date and the largest end date of the two overlapping rows and compare that range to the third row. If steps 1 & 2 are true for this comparison as well, row 3 becomes part of the overlapping group along with rows 1 & 2, and we take the smallest and largest dates of the three rows and compare them to the fourth row.
This process continues within each partition of customerid and locationid, as shown in the code below.
If rows 1 and 2 do not overlap, the comparison moves to rows 2 & 3 and the same steps are performed.
There might be multiple separate overlapping groups inside a partition, so it's possible that rows 1 & 2 are one group, rows 3 & 4 are another group, and rows 5 & 6 are not part of any group.
Rows 5 & 6 can either be part of the 2nd group or not part of any group, since this is a sequential comparison.
The result I want should be something like this
+--------+----------+----------+------------------+----------------+-------------+
|recordno|customerid|locationid|process_start_date|process_end_date|overlap_group|
+--------+----------+----------+------------------+----------------+-------------+
|       1|   2277953|         A|        2015-03-13|      2016-04-15|            1|
|       2|   2277953|         A|        2016-04-04|      2019-12-31|            1|
|       3|   2277953|         A|        2019-06-06|      2019-06-20|            1|
|       4|   2277953|         A|        2019-06-30|      2019-12-31|            1|
|       5|   2277953|         A|        2020-01-01|      2020-12-31|            2|
|       6|   2277953|         A|        2020-06-30|      2020-12-31|            2|
+--------+----------+----------+------------------+----------------+-------------+
As per my logic, row 4 should be part of group 1, but with my current code it comes out as a separate group 2.
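Just to make the rule concrete, this is the behaviour I am after, written as a plain sequential Python loop over the sample rows above (illustration only; this row-by-row approach is exactly what I want to avoid at scale):
from datetime import date

# sample rows from the table above, already sorted by process_start_date
rows = [
    (date(2015, 3, 13), date(2016, 4, 15)),
    (date(2016, 4, 4), date(2019, 12, 31)),
    (date(2019, 6, 6), date(2019, 6, 20)),
    (date(2019, 6, 30), date(2019, 12, 31)),
    (date(2020, 1, 1), date(2020, 12, 31)),
    (date(2020, 6, 30), date(2020, 12, 31)),
]
groups = []
group_no = 0
current_end = None
for start, end in rows:
    if current_end is None or start > current_end:
        # no overlap with the running group -> start a new group
        group_no += 1
        current_end = end
    else:
        # overlap -> keep the largest end date seen so far in the group
        current_end = max(current_end, end)
    groups.append(group_no)
print(groups)  # [1, 1, 1, 1, 2, 2]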
My code is as follows
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lag, when, sum as spark_sum
from pyspark.sql.window import Window
# Initialize Spark session
spark = SparkSession.builder.appName("OptimizedOverlapGrouping").getOrCreate()
# Sample data
data = [
    (1, 2277953, 'A', '2015-03-13', '2016-04-15'),
    (2, 2277953, 'A', '2016-04-04', '2019-12-31'),
    (3, 2277953, 'A', '2019-06-06', '2019-06-20'),
    (4, 2277953, 'A', '2019-06-30', '2019-12-31'),
    (5, 2277953, 'A', '2020-01-01', '2020-12-31'),
    (6, 2277953, 'A', '2020-06-30', '2020-12-31')
]
# Create DataFrame
df = spark.createDataFrame(data, ['recordno', 'customerid', 'locationid', 'process_start_date', 'process_end_date'])
df = df.withColumn("process_start_date", col("process_start_date").cast("date"))
df = df.withColumn("process_end_date", col("process_end_date").cast("date"))
# Step 1: Define window for partitioning by group and ordering by start date
window_spec = Window.partitionBy("customerid", "locationid").orderBy("process_start_date")
# Step 2: Compare each row with the previous row to detect non-overlapping groups
df_with_lag = df.withColumn(
    "prev_EndDate", lag("process_end_date").over(window_spec)
)
# Step 3: Identify the start of a new overlap group
df_with_group_flag = df_with_lag.withColumn(
    "is_new_group",
    when(
        (col("prev_EndDate").isNull()) | (col("process_start_date") > col("prev_EndDate")),
        1
    ).otherwise(0)
)
# Step 4: Generate sequential group numbers for overlapping records
df_with_overlap_group = df_with_group_flag.withColumn(
    "overlap_group",
    spark_sum("is_new_group").over(window_spec)
).drop("prev_EndDate", "is_new_group")
df_with_overlap_group.show()
The problem with the current code, given your example, is that it only looks at the immediately preceding row to check for an overlap. Instead, if you consider all of the previous rows in the partition (via a running maximum of the end dates), it is solved.
Something like
# `spark_max` is not imported above, so bring it in alongside the other functions
from pyspark.sql.functions import max as spark_max

# Step 2: Compute the maximum `process_end_date` over all previous rows in the partition
df_with_max_prev_end_date = df.withColumn(
    "max_EndDate_prev",
    # frame: every row from the start of the partition up to, but not including, the current row
    spark_max("process_end_date").over(window_spec.rowsBetween(Window.unboundedPreceding, Window.currentRow - 1))
)
# Step 3: Identify the start of a new group
df_with_group_flag = df_with_max_prev_end_date.withColumn(
    "is_new_group",
    when(
        (col("max_EndDate_prev").isNull()) | (col("process_start_date") > col("max_EndDate_prev")),
        1
    ).otherwise(0)
)
# Step 4: Generate sequential group numbers for overlapping records
df_with_overlap_group = df_with_group_flag.withColumn(
    "overlap_group",
    spark_sum("is_new_group").over(window_spec)
).drop("is_new_group", "max_EndDate_prev")
# Show the final result
df_with_overlap_group.show()
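With the running maximum, row 4 (start 2019-06-30) is compared against 2019-12-31 carried forward from row 2, so it joins group 1, while rows 5 & 6 form group 2, matching the desired output in the question. If you want a quick sanity check on the sample data (assuming the DataFrame names above):
# expected grouping per the desired output: rows 1-4 in group 1, rows 5-6 in group 2
expected = {1: 1, 2: 1, 3: 1, 4: 1, 5: 2, 6: 2}
actual = {row["recordno"]: row["overlap_group"] for row in df_with_overlap_group.collect()}
assert actual == expected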