Tags: python, pyspark, count, spark-window-function

PySpark: count over a window with reset


I have a PySpark DataFrame which looks like this:

df = spark.createDataFrame(
    data=[
    (1, "GERMANY", "20230606", True),
    (2, "GERMANY", "20230620", False),
    (3, "GERMANY", "20230627", True),
    (4, "GERMANY", "20230705", True),
    (5, "GERMANY", "20230714", False),
    (6, "GERMANY", "20230715", True),
    ],
    schema=["ID", "COUNTRY", "DATE", "FLAG"]
)
df.show()
+---+-------+--------+-----+
| ID|COUNTRY|    DATE| FLAG|
+---+-------+--------+-----+
|  1|GERMANY|20230606| true|
|  2|GERMANY|20230620|false|
|  3|GERMANY|20230627| true|
|  4|GERMANY|20230705| true|
|  5|GERMANY|20230714|false|
|  6|GERMANY|20230715| true|
+---+-------+--------+-----+

The DataFrame contains more countries. I want to create a new column COUNT_WITH_RESET with the following logic: within each COUNTRY, ordered by DATE, the count increments by 1 for every consecutive row where FLAG is true and resets to 0 whenever FLAG is false.

This should be the output for the example above.

+---+-------+--------+-----+----------------+
| ID|COUNTRY|    DATE| FLAG|COUNT_WITH_RESET|
+---+-------+--------+-----+----------------+
|  1|GERMANY|20230606| true|               1|
|  2|GERMANY|20230620|false|               0|
|  3|GERMANY|20230627| true|               1|
|  4|GERMANY|20230705| true|               2|
|  5|GERMANY|20230714|false|               0|
|  6|GERMANY|20230715| true|               1|
+---+-------+--------+-----+----------------+

I have tried row_number() over a window, but I can't manage to reset the count. I have also tried .rowsBetween(Window.unboundedPreceding, Window.currentRow). Here's my approach:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_reset = Window.partitionBy("COUNTRY").orderBy("DATE")

df_with_reset = (
    df
    .withColumn("COUNT_WITH_RESET", F.when(~F.col("FLAG"), 0)
                .otherwise(F.row_number().over(window_reset)))
)

df_with_reset.show()
+---+-------+--------+-----+----------------+
| ID|COUNTRY|    DATE| FLAG|COUNT_WITH_RESET|
+---+-------+--------+-----+----------------+
|  1|GERMANY|20230606| true|               1|
|  2|GERMANY|20230620|false|               0|
|  3|GERMANY|20230627| true|               3|
|  4|GERMANY|20230705| true|               4|
|  5|GERMANY|20230714|false|               0|
|  6|GERMANY|20230715| true|               6|
+---+-------+--------+-----+----------------+

This is obviously wrong, as my window partitions only by country. Am I on the right track? Is there a specific built-in function in PySpark to achieve this? Do I need a UDF? Any help would be appreciated.


Solution

  • Partition the dataframe by COUNTRY, then calculate the cumulative sum over the inverted FLAG column to assign group numbers. This distinguishes the blocks of rows that each start with a false.

    W1 = Window.partitionBy('COUNTRY').orderBy('DATE')
    df1 = df.withColumn('blocks', F.sum((~F.col('FLAG')).cast('long')).over(W1))
    
    df1.show()
    # +---+-------+--------+-----+------+
    # | ID|COUNTRY|    DATE| FLAG|blocks|
    # +---+-------+--------+-----+------+
    # |  1|GERMANY|20230606| true|     0|
    # |  2|GERMANY|20230620|false|     1|
    # |  3|GERMANY|20230627| true|     1|
    # |  4|GERMANY|20230705| true|     1|
    # |  5|GERMANY|20230714|false|     2|
    # |  6|GERMANY|20230715| true|     2|
    # +---+-------+--------+-----+------+
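The block-numbering trick can be checked outside Spark as well; here is a minimal plain-Python sketch of the same cumulative-sum-over-inverted-flags idea, using the FLAG values from the example:

```python
from itertools import accumulate

# the FLAG column for GERMANY, in DATE order
flags = [True, False, True, True, False, True]

# each False starts a new block, so the running total of the
# inverted flags yields the block number for every row
blocks = list(accumulate(int(not f) for f in flags))
print(blocks)  # [0, 1, 1, 1, 2, 2]
```

This matches the `blocks` column produced by the window above.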
    

  • Partition the dataframe by COUNTRY together with blocks, then calculate the row number over the ordered partition to create the sequential counter.

    W2 = Window.partitionBy('COUNTRY', 'blocks').orderBy('DATE')
    df1 = df1.withColumn('COUNT_WITH_RESET', F.row_number().over(W2) - 1)
    
    
    df1.show()
    # +---+-------+--------+-----+------+----------------+
    # | ID|COUNTRY|    DATE| FLAG|blocks|COUNT_WITH_RESET|
    # +---+-------+--------+-----+------+----------------+
    # |  1|GERMANY|20230606| true|     0|               0|
    # |  2|GERMANY|20230620|false|     1|               0|
    # |  3|GERMANY|20230627| true|     1|               1|
    # |  4|GERMANY|20230705| true|     1|               2|
    # |  5|GERMANY|20230714|false|     2|               0|
    # |  6|GERMANY|20230715| true|     2|               1|
    # +---+-------+--------+-----+------+----------------+