pyspark, spark-window-function

Window function based on a condition


I have the following DF:

|-----------------------|
|Date       | Val | Cond|
|-----------------------|
|2022-01-08 | 2   | 0   |
|2022-01-09 | 4   | 1   |
|2022-01-10 | 6   | 1   |
|2022-01-11 | 8   | 0   |
|2022-01-12 | 2   | 1   |
|2022-01-13 | 5   | 1   |
|2022-01-14 | 7   | 0   |
|2022-01-15 | 9   | 0   | 
|-----------------------|

For every date, I need to sum the Val of the two most recent earlier dates where Cond = 1. My expected output is:

|-----------------|
|Date       | Sum |
|-----------------|
|2022-01-08 | 0   |  No sum, because there aren't two earlier dates with Cond = 1
|2022-01-09 | 0   |  No sum, because there aren't two earlier dates with Cond = 1
|2022-01-10 | 0   |  No sum, because there aren't two earlier dates with Cond = 1
|2022-01-11 | 10  | (4+6)
|2022-01-12 | 10  | (4+6)
|2022-01-13 | 8   | (2+6)
|2022-01-14 | 7   | (5+2)
|2022-01-15 | 7   | (5+2)
|-----------------|

I've tried to get the output DF using this code:

df = df.where("Cond= 1").withColumn(
    "ListView",
    f.collect_list("Val").over(windowSpec.rowsBetween(-2, -1))
)

But when I use .where("Cond = 1"), I exclude the dates where Cond is equal to zero.

I found the following answer, but it didn't help me:

Window.rowsBetween - only consider rows fulfilling a specific condition (e.g. not being null)

How can I achieve my expected output using window functions?

The MVCE:

from pyspark.sql.types import DateType, IntegerType, StructField, StructType

data_1=[
    ("2022-01-08",2,0),
    ("2022-01-09",4,1),
    ("2022-01-10",6,1),
    ("2022-01-11",8,0),
    ("2022-01-12",2,1),
    ("2022-01-13",5,1),
    ("2022-01-14",7,0),
    ("2022-01-15",9,0) 
]

schema_1 = StructType([
    StructField("Date", DateType(),True),
    StructField("Val", IntegerType(),True),
    StructField("Cond", IntegerType(),True)
  ])

df_1 = spark.createDataFrame(data=data_1,schema=schema_1)

Solution

  • The following should do the trick (but I'm sure it can be further optimized).

    Setup:

    from pyspark.sql import Window
    from pyspark.sql.functions import col, max, sum, to_date, when
    from pyspark.sql.types import IntegerType, StringType, StructField, StructType

    data_1=[
        ("2022-01-08",2,0),
        ("2022-01-09",4,1),
        ("2022-01-10",6,1),
        ("2022-01-11",8,0),
        ("2022-01-12",2,1),
        ("2022-01-13",5,1),
        ("2022-01-14",7,0),
        ("2022-01-15",9,0),
        ("2022-01-16",9,0),
        ("2022-01-17",9,0)
    ]
    
    schema_1 = StructType([
        StructField("Date", StringType(),True),
        StructField("Val", IntegerType(),True),
        StructField("Cond", IntegerType(),True)
      ])
    
    df_1 = spark.createDataFrame(data=data_1,schema=schema_1)
    df_1 = df_1.withColumn('Date', to_date("Date", "yyyy-MM-dd"))
    df_1.show()
    
    +----------+---+----+
    |      Date|Val|Cond|
    +----------+---+----+
    |2022-01-08|  2|   0|
    |2022-01-09|  4|   1|
    |2022-01-10|  6|   1|
    |2022-01-11|  8|   0|
    |2022-01-12|  2|   1|
    |2022-01-13|  5|   1|
    |2022-01-14|  7|   0|
    |2022-01-15|  9|   0|
    |2022-01-16|  9|   0|
    |2022-01-17|  9|   0|
    +----------+---+----+
    

    Create a new DF with only the Cond==1 rows and, for each of them, sum Val over the current and the previous Cond==1 row:

    windowSpec = Window.partitionBy("Cond").orderBy("Date")
    df_2 = df_1.where(df_1.Cond==1).withColumn(
        "Sum",
        sum("Val").over(windowSpec.rowsBetween(-1, 0))
    ).withColumn('date_1', col('date')).drop('date')
    df_2.show()
    
    +---+----+---+----------+
    |Val|Cond|Sum|    date_1|
    +---+----+---+----------+
    |  4|   1|  4|2022-01-09|
    |  6|   1| 10|2022-01-10|
    |  2|   1|  8|2022-01-12|
    |  5|   1|  7|2022-01-13|
    +---+----+---+----------+
    

    Do a left join to get the sum into the original data frame, and set the sum to zero for the rows with Cond==0:

    df_3 = df_1.join(df_2.select('sum', col('date_1')), df_1.Date == df_2.date_1, "left").drop('date_1').fillna(0)
    df_3.show()
    
    +----------+---+----+---+
    |      Date|Val|Cond|sum|
    +----------+---+----+---+
    |2022-01-08|  2|   0|  0|
    |2022-01-09|  4|   1|  4|
    |2022-01-10|  6|   1| 10|
    |2022-01-11|  8|   0|  0|
    |2022-01-12|  2|   1|  8|
    |2022-01-13|  5|   1|  7|
    |2022-01-14|  7|   0|  0|
    |2022-01-15|  9|   0|  0|
    |2022-01-16|  9|   0|  0|
    |2022-01-17|  9|   0|  0|
    +----------+---+----+---+
    

    Do a cumulative sum on the condition column:

    df_3 = df_3.withColumn('cond_sum', sum('cond').over(Window.orderBy('Date')))
    df_3.show()
    
    +----------+---+----+---+--------+
    |      Date|Val|Cond|sum|cond_sum|
    +----------+---+----+---+--------+
    |2022-01-08|  2|   0|  0|       0|
    |2022-01-09|  4|   1|  4|       1|
    |2022-01-10|  6|   1| 10|       2|
    |2022-01-11|  8|   0|  0|       2|
    |2022-01-12|  2|   1|  8|       3|
    |2022-01-13|  5|   1|  7|       4|
    |2022-01-14|  7|   0|  0|       4|
    |2022-01-15|  9|   0|  0|       4|
    |2022-01-16|  9|   0|  0|       4|
    |2022-01-17|  9|   0|  0|       4|
    +----------+---+----+---+--------+
    

    Finally, for each partition where the cond_sum is greater than 1, use the max sum for that partition:

    df_3.withColumn(
        'sum',
        when(df_3.cond_sum > 1, max('sum').over(Window.partitionBy('cond_sum'))).otherwise(0)
    ).show()
    
    +----------+---+----+---+--------+
    |      Date|Val|Cond|sum|cond_sum|
    +----------+---+----+---+--------+
    |2022-01-08|  2|   0|  0|       0|
    |2022-01-09|  4|   1|  0|       1|
    |2022-01-10|  6|   1| 10|       2|
    |2022-01-11|  8|   0| 10|       2|
    |2022-01-12|  2|   1|  8|       3|
    |2022-01-13|  5|   1|  7|       4|
    |2022-01-14|  7|   0|  7|       4|
    |2022-01-15|  9|   0|  7|       4|
    |2022-01-16|  9|   0|  7|       4|
    |2022-01-17|  9|   0|  7|       4|
    +----------+---+----+---+--------+