I have a PySpark DataFrame which looks like this:
df = spark.createDataFrame(
    data=[
        (1, "GERMANY", "20230606", True),
        (2, "GERMANY", "20230620", False),
        (3, "GERMANY", "20230627", True),
        (4, "GERMANY", "20230705", True),
        (5, "GERMANY", "20230714", False),
        (6, "GERMANY", "20230715", True),
    ],
    schema=["ID", "COUNTRY", "DATE", "FLAG"],
)
df.show()
+---+-------+--------+-----+
| ID|COUNTRY| DATE| FLAG|
+---+-------+--------+-----+
| 1|GERMANY|20230606| true|
| 2|GERMANY|20230620|false|
| 3|GERMANY|20230627| true|
| 4|GERMANY|20230705| true|
| 5|GERMANY|20230714|false|
| 6|GERMANY|20230715| true|
+---+-------+--------+-----+
The DataFrame has more countries. I want to create a new column COUNT_WITH_RESET following this logic:
If FLAG=False, then COUNT_WITH_RESET=0.
If FLAG=True, then COUNT_WITH_RESET should count the number of rows starting from the previous date where FLAG=False, for that specific country.
This should be the output for the example above.
+---+-------+--------+-----+----------------+
| ID|COUNTRY| DATE| FLAG|COUNT_WITH_RESET|
+---+-------+--------+-----+----------------+
| 1|GERMANY|20230606| true| 1|
| 2|GERMANY|20230620|false| 0|
| 3|GERMANY|20230627| true| 1|
| 4|GERMANY|20230705| true| 2|
| 5|GERMANY|20230714|false| 0|
| 6|GERMANY|20230715| true| 1|
+---+-------+--------+-----+----------------+
I have tried with row_number() over a window, but I can't manage to reset the count. I have also tried with .rowsBetween(Window.unboundedPreceding, Window.currentRow). Here's my approach:
from pyspark.sql.window import Window
import pyspark.sql.functions as F
window_reset = Window.partitionBy("COUNTRY").orderBy("DATE")
df_with_reset = (
    df
    .withColumn("COUNT_WITH_RESET", F.when(~F.col("FLAG"), 0)
                .otherwise(F.row_number().over(window_reset)))
)
df_with_reset.show()
+---+-------+--------+-----+----------------+
| ID|COUNTRY| DATE| FLAG|COUNT_WITH_RESET|
+---+-------+--------+-----+----------------+
| 1|GERMANY|20230606| true| 1|
| 2|GERMANY|20230620|false| 0|
| 3|GERMANY|20230627| true| 3|
| 4|GERMANY|20230705| true| 4|
| 5|GERMANY|20230714|false| 0|
| 6|GERMANY|20230715| true| 6|
+---+-------+--------+-----+----------------+
This is obviously wrong as my window is partitioning only by country, but am I on the right track? Is there a specific built-in function in PySpark to achieve this? Do I need a UDF? Any help would be appreciated.
Partition the dataframe by COUNTRY, then calculate the cumulative sum over the inverted FLAG column to assign group numbers that distinguish the blocks of rows starting with a FLAG=false row. Since the running sum only increments on false rows, every true row that follows inherits the same block number.
W1 = Window.partitionBy('COUNTRY').orderBy('DATE')
df1 = df.withColumn('blocks', F.sum((~F.col('FLAG')).cast('long')).over(W1))
df1.show()
# +---+-------+--------+-----+------+
# | ID|COUNTRY| DATE| FLAG|blocks|
# +---+-------+--------+-----+------+
# | 1|GERMANY|20230606| true| 0|
# | 2|GERMANY|20230620|false| 1|
# | 3|GERMANY|20230627| true| 1|
# | 4|GERMANY|20230705| true| 1|
# | 5|GERMANY|20230714|false| 2|
# | 6|GERMANY|20230715| true| 2|
# +---+-------+--------+-----+------+
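Since the question mentions .rowsBetween(Window.unboundedPreceding, Window.currentRow), note that the running sum above relies on the default frame of an ordered window, which already ends at the current row. A minimal sketch with the frame spelled out (the name W1_explicit is mine, not part of the original answer) behaves the same here as long as there are no duplicate DATE values within a country:

# Same blocks computation with the window frame written out explicitly.
# The default frame of an ordered window is RANGE BETWEEN UNBOUNDED PRECEDING
# AND CURRENT ROW; a ROWS frame is equivalent here because the dates are unique.
W1_explicit = (
    Window.partitionBy('COUNTRY')
    .orderBy('DATE')
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
df1 = df.withColumn('blocks', F.sum((~F.col('FLAG')).cast('long')).over(W1_explicit))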
Partition the dataframe by COUNTRY along with blocks, then calculate the row number over the ordered partition to create a sequential counter; subtracting 1 makes the FLAG=false row that opens each block count as 0.
W2 = Window.partitionBy('COUNTRY', 'blocks').orderBy('DATE')
df1 = df1.withColumn('COUNT_WITH_RESET', F.row_number().over(W2) - 1)
df1.show()
# +---+-------+--------+-----+------+----------------+
# | ID|COUNTRY| DATE| FLAG|blocks|COUNT_WITH_RESET|
# +---+-------+--------+-----+------+----------------+
# | 1|GERMANY|20230606| true| 0| 0|
# | 2|GERMANY|20230620|false| 1| 0|
# | 3|GERMANY|20230627| true| 1| 1|
# | 4|GERMANY|20230705| true| 1| 2|
# | 5|GERMANY|20230714|false| 2| 0|
# | 6|GERMANY|20230715| true| 2| 1|
# +---+-------+--------+-----+------+----------------+
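For completeness, here is a minimal end-to-end sketch that chains both steps and drops the helper column (the names blocks_col and result are mine). It also uses a small variant for the final counter: the expected output in the question has COUNT_WITH_RESET=1 for ID=1 even though no earlier FLAG=false row exists, whereas row_number() - 1 yields 0 there; a running sum of FLAG within each block reproduces the question's expected output exactly.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

W1 = Window.partitionBy('COUNTRY').orderBy('DATE')

# Running count of FLAG=false rows per country -> block id.
blocks_col = F.sum((~F.col('FLAG')).cast('long')).over(W1)

W2 = Window.partitionBy('COUNTRY', 'blocks').orderBy('DATE')

result = (
    df
    .withColumn('blocks', blocks_col)
    # Running count of FLAG=true rows within the block: false rows stay at 0,
    # and a leading block with no false row starts counting at 1.
    .withColumn('COUNT_WITH_RESET', F.sum(F.col('FLAG').cast('long')).over(W2))
    .drop('blocks')
)
result.show()

Whether a country's leading block should start at 0 or 1 is the only place row_number() - 1 and the running sum of FLAG differ, so pick whichever matches the intended semantics.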