python pyspark databricks

Find overlapping rows in Spark / Python


I am working with PySpark and want to implement the routine below without using a recursive approach: my data is too large (millions of rows), so a recursive routine is very expensive. The rule is as follows.

Find overlapping process dates between subsequent rows sequentially. I have numbered the rows for sequencing already.

Steps:

  1. If process_start_date of the previous row is <= process_end_date of the next row, and
  2. process_end_date of the previous row is >= process_start_date of the next row,

Then the rows overlap; mark the is_overlap flag as 1 for both rows. I have coded this successfully up to this step, but the next step of the rule is quite complex.
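
For illustration only, the pairwise check from steps 1 and 2 can be expressed with lag/lead columns roughly as below. This is just a sketch assuming a DataFrame df like the one created further down; helper names such as overlaps_prev and flagged are illustrative and not from my actual code.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.partitionBy("customerid", "locationid").orderBy("process_start_date")

flagged = (
    df
    .withColumn("prev_start", F.lag("process_start_date").over(w))
    .withColumn("prev_end", F.lag("process_end_date").over(w))
    # steps 1 & 2: previous start <= current end AND previous end >= current start
    .withColumn(
        "overlaps_prev",
        (F.col("prev_start") <= F.col("process_end_date"))
        & (F.col("prev_end") >= F.col("process_start_date"))
    )
    # a row gets is_overlap = 1 when it overlaps its predecessor or its successor
    .withColumn(
        "is_overlap",
        F.when(F.col("overlaps_prev") | F.lead("overlaps_prev").over(w), 1).otherwise(0)
    )
)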

If the two rows overlap, take the smallest start date and the largest end date of the two overlapping rows and compare them to the third row. If steps 1 and 2 are true for this comparison, row 3 becomes part of the overlapping group along with rows 1 and 2, and we then take the smallest and largest dates of the three rows and compare them to the fourth row.

This process continues within each partition of customerid and locationid, as shown in the code below.

If rows 1 and 2 do not overlap, the comparison moves on to rows 2 and 3 and performs the same steps.

There might be multiple separate overlapping groups inside a partition, so it is possible that rows 1 and 2 are one group, rows 3 and 4 are another group, and rows 5 and 6 are not part of any group.

Rows 5 and 6 can either be part of the second group or not part of any group, since this is a sequential comparison.
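
To make the rule concrete, here is a plain-Python sketch of this sequential comparison for a single customerid and locationid partition, using the sample rows shown below. This is only an illustration of the rule, not the PySpark routine itself.

from datetime import date

rows = [  # (recordno, process_start_date, process_end_date), already in sequence
    (1, date(2015, 3, 13), date(2016, 4, 15)),
    (2, date(2016, 4, 4), date(2019, 12, 31)),
    (3, date(2019, 6, 6), date(2019, 6, 20)),
    (4, date(2019, 6, 30), date(2019, 12, 31)),
    (5, date(2020, 1, 1), date(2020, 12, 31)),
    (6, date(2020, 6, 30), date(2020, 12, 31)),
]

groups = []
group_id = 0
cur_start = cur_end = None
for recordno, start, end in rows:
    # steps 1 & 2, but checked against the running smallest start / largest end of the group
    if cur_start is not None and cur_start <= end and cur_end >= start:
        cur_start, cur_end = min(cur_start, start), max(cur_end, end)
    else:
        group_id += 1
        cur_start, cur_end = start, end
    groups.append((recordno, group_id))

print(groups)  # [(1, 1), (2, 1), (3, 1), (4, 1), (5, 2), (6, 2)]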

The result I want should look something like this:

+--------+----------+----------+------------------+----------------+-------------+
|recordno|customerid|locationid|process_start_date|process_end_date|overlap_group|
+--------+----------+----------+------------------+----------------+-------------+
|       1|   2277953|         A|        2015-03-13|      2016-04-15|            1|
|       2|   2277953|         A|        2016-04-04|      2019-12-31|            1|
|       3|   2277953|         A|        2019-06-06|      2019-06-20|            1|
|       4|   2277953|         A|        2019-06-30|      2019-12-31|            1|
|       5|   2277953|         A|        2020-01-01|      2020-12-31|            2|
|       6|   2277953|         A|        2020-06-30|      2020-12-31|            2|
+--------+----------+----------+------------------+----------------+-------------+

As per my logic, row 4 should be part of group 1, but it is appearing in a separate group 2.

My code is as follows:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lag, when, sum as spark_sum
from pyspark.sql.window import Window

# Initialize Spark session
spark = SparkSession.builder.appName("OptimizedOverlapGrouping").getOrCreate()

# Sample data
data = [
    (1, 2277953, 'A', '2015-03-13', '2016-04-15'),
    (2, 2277953, 'A', '2016-04-04', '2019-12-31'),
    (3, 2277953, 'A', '2019-06-06', '2019-06-20'),
    (4, 2277953, 'A', '2019-06-30', '2019-12-31'),
    (5, 2277953, 'A', '2020-01-01', '2020-12-31'),
    (6, 2277953, 'A', '2020-06-30', '2020-12-31')
 ]

# Create DataFrame
df = spark.createDataFrame(data, ['recordno', 'customerid', 'locationid', 'process_start_date', 'process_end_date'])
df = df.withColumn("process_start_date", col("process_start_date").cast("date"))
df = df.withColumn("process_end_date", col("process_end_date").cast("date"))

# Step 1: Define window for partitioning by group and ordering by start date
window_spec = Window.partitionBy("customerid", "locationid").orderBy("process_start_date")

# Step 2: Compare each row with the previous row to detect non-overlapping groups
df_with_lag = df.withColumn(
    "prev_EndDate", lag("process_end_date").over(window_spec)
)

# Step 3: Identify the start of a new overlap group
df_with_group_flag = df_with_lag.withColumn(
    "is_new_group",
    when(
        (col("prev_EndDate").isNull()) | (col("process_start_date") > col("prev_EndDate")),
        1
    ).otherwise(0)
)

# Step 4: Generate sequential group numbers for overlapping records
df_with_overlap_group = df_with_group_flag.withColumn(
    "overlap_group",
    spark_sum("is_new_group").over(window_spec)
).drop("prev_EndDate", "is_new_group")

df_with_overlap_group.show()

Solution

  • The problem with the current code on your example is that it looks only at the immediately preceding row to check for an overlap. If you instead consider all the preceding rows within the group, the problem is solved.

    Something like

    # Reuses `df` and `window_spec` from the question; `spark_max` still needs importing
    from pyspark.sql.functions import max as spark_max

    # Step 2: Compute the maximum `process_end_date` over all previous rows in the group
    df_with_max_prev_end_date = df.withColumn(
        "max_EndDate_prev",
        spark_max("process_end_date").over(
            window_spec.rowsBetween(Window.unboundedPreceding, -1)
        )
    )
    
    # Step 3: Identify the start of a new group
    df_with_group_flag = df_with_max_prev_end_date.withColumn(
        "is_new_group",
        when(
            (col("max_EndDate_prev").isNull()) | (col("process_start_date") > col("max_EndDate_prev")),
            1
        ).otherwise(0)
    )
    
    # Step 4: Generate sequential group numbers for overlapping records
    df_with_overlap_group = df_with_group_flag.withColumn(
        "overlap_group",
        spark_sum("is_new_group").over(window_spec)
    ).drop("is_new_group", "max_EndDate_prev")
    
    # Show the final result
    df_with_overlap_group.show()
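
    With the running maximum, row 4's start date (2019-06-30) is compared against the largest end date seen so far in the group (2019-12-31) rather than only against row 3's end date (2019-06-20), so rows 1 to 4 land in group 1 and rows 5 and 6 in group 2, matching the expected output in the question.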