pyspark, apache-spark-sql

How to get the next non-null value within a group in PySpark


I have a sales transaction dataframe which looks as follows:

id   date     amount  last_order_date
001  2021-01  100     2020-11
001  2021-02  0       null
001  2021-03  0       null
001  2021-04  20      2021-01
001  2021-05  0       null
001  2021-06  0       null
001  2021-07  0       null
001  2021-08  50      2021-04
002  2022-03  5       2022-01
002  2022-04  40      2022-03
002  2022-05  0       null
002  2022-06  0       null
002  2022-07  0       null
002  2022-08  35      2022-04

I want to replace the null values in column last_order_date with the next non-null value of that same column within each group (id), so that I get the following dataframe:

id   date     amount  last_order_date
001  2021-01  100     2020-11
001  2021-02  0       2021-01
001  2021-03  0       2021-01
001  2021-04  20      2021-01
001  2021-05  0       2021-04
001  2021-06  0       2021-04
001  2021-07  0       2021-04
001  2021-08  50      2021-04
002  2022-03  5       2022-01
002  2022-04  40      2022-03
002  2022-05  0       2022-04
002  2022-06  0       2022-04
002  2022-07  0       2022-04
002  2022-08  35      2022-04

From what I have found, the method relies on using the last/first function over a window partitioned by id. However, when I apply the following code:

df.withColumn(
    'last_order_date',
    F.last('last_order_date', ignorenulls=True)
     .over(Window.partitionBy('id').orderBy('date'))
)

I get the null values replaced by the previous non-null value instead:

id   date     amount  last_order_date
001  2021-01  100     2020-11
001  2021-02  0       2020-11
001  2021-03  0       2020-11
001  2021-04  20      2021-01
001  2021-05  0       2021-01
001  2021-06  0       2021-01
001  2021-07  0       2021-01
001  2021-08  50      2021-04
002  2022-03  5       2022-01
002  2022-04  40      2022-03
002  2022-05  0       2022-03
002  2022-06  0       2022-03
002  2022-07  0       2022-03
002  2022-08  35      2022-04

I am not quite sure where the problem lies. Thanks in advance for your help.


Solution

  • You're very close. When a window has an orderBy but no explicit frame, Spark defaults the frame to rangeBetween(unboundedPreceding, currentRow), so last(..., ignorenulls=True) looks backwards and returns the most recent non-null value. If you use first with a sliding frame that starts at the current row and extends to the end of the partition instead, you can achieve your required result.

    import sys
    import pyspark.sql.functions as func
    from pyspark.sql.window import Window as wd
    
    # take the first non-null value in the frame running from the
    # current row (0) to the end of the partition (sys.maxsize)
    data_sdf. \
        withColumn('last_order_dt_filled',
                   func.first('last_order_dt', ignorenulls=True).
                   over(wd.partitionBy('id').orderBy('dt').rowsBetween(0, sys.maxsize))
                   ). \
        show()
    
    # +---+-------+---+-------------+--------------------+
    # | id|     dt|amt|last_order_dt|last_order_dt_filled|
    # +---+-------+---+-------------+--------------------+
    # |001|2021-01|100|      2020-11|             2020-11|
    # |001|2021-02|  0|         null|             2021-01|
    # |001|2021-03|  0|         null|             2021-01|
    # |001|2021-04| 20|      2021-01|             2021-01|
    # |001|2021-05|  0|         null|             2021-04|
    # |001|2021-06|  0|         null|             2021-04|
    # |001|2021-07|  0|         null|             2021-04|
    # |001|2021-08| 50|      2021-04|             2021-04|
    # |002|2022-03|  5|      2022-01|             2022-01|
    # |002|2022-04| 40|      2022-03|             2022-03|
    # |002|2022-05|  0|         null|             2022-04|
    # |002|2022-06|  0|         null|             2022-04|
    # |002|2022-07|  0|         null|             2022-04|
    # |002|2022-08| 35|      2022-04|             2022-04|
    # +---+-------+---+-------------+--------------------+
    

    The idea is to take the first non-null value in the frame running from the current row to the end of the group.
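
    As an aside, the same frame can be spelled with Spark's named frame boundaries, which reads a little clearer than (0, sys.maxsize) and drops the need for the sys import. A minimal sketch of the equivalent fill, assuming the same data_sdf as above:

    import pyspark.sql.functions as func
    from pyspark.sql.window import Window as wd

    # frame from the current row to the end of the partition,
    # written with named boundaries instead of (0, sys.maxsize)
    w = wd.partitionBy('id').orderBy('dt') \
          .rowsBetween(wd.currentRow, wd.unboundedFollowing)

    data_sdf.withColumn('last_order_dt_filled',
                        func.first('last_order_dt', ignorenulls=True).over(w)). \
        show()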