dataframe, apache-spark, pyspark, conditional-statements, lead

PySpark lead based on condition


I have a dataset such as:

Condition | Date
0 | 2019/01/10
1 | 2019/01/11
0 | 2019/01/15
1 | 2019/01/16
1 | 2019/01/19
0 | 2019/01/23
0 | 2019/01/25
1 | 2019/01/29
1 | 2019/01/30

I would like to get the next value of the date column when condition == 1 was met.

The desired output would be something like:

Condition | Date | Lead
0 | 2019/01/10 | 2019/01/15
1 | 2019/01/11 | 2019/01/16
0 | 2019/01/15 | 2019/01/23
1 | 2019/01/16 | 2019/01/19
1 | 2019/01/19 | 2019/01/29
0 | 2019/01/23 | 2019/01/25
0 | 2019/01/25 | NaN
1 | 2019/01/29 | 2019/01/30
1 | 2019/01/30 | NaN

How can I achieve this?

Please keep in mind it's a very large dataset, which I will have to partition and group by a UUID, so the solution has to be reasonably performant.


Solution

  • To get the next value of the date column where condition == 1, we can use the first window function over a forward-looking frame: wrap the date in when() so that rows where the condition is not met become null, then skip those nulls with ignorenulls=True. This emulates a conditional lead.

    import sys
    from pyspark.sql import functions as func
    from pyspark.sql.window import Window as wd

    data_sdf. \
        withColumn('dt_w_cond1_lead',
                   # first non-null value (i.e. the first dt where cond == 1) among all rows after the current one
                   func.first(func.when(func.col('cond') == 1, func.col('dt')), ignorenulls=True).
                   over(wd.partitionBy().orderBy('dt').rowsBetween(1, sys.maxsize))
                   ). \
        show()
    
    # +----+----------+---------------+
    # |cond|        dt|dt_w_cond1_lead|
    # +----+----------+---------------+
    # |   0|2019-01-10|     2019-01-11|
    # |   1|2019-01-11|     2019-01-16|
    # |   0|2019-01-15|     2019-01-16|
    # |   1|2019-01-16|     2019-01-19|
    # |   1|2019-01-19|     2019-01-29|
    # |   0|2019-01-23|     2019-01-29|
    # |   0|2019-01-25|     2019-01-29|
    # |   1|2019-01-29|     2019-01-30|
    # |   1|2019-01-30|           null|
    # +----+----------+---------------+
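
  • For reference, the data_sdf used above can be rebuilt from the sample in the question; this is just a sketch, with the dates kept as yyyy-MM-dd strings and the column names cond and dt matching the output shown.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # the question's sample data (dates as yyyy-MM-dd strings)
    data_sdf = spark.createDataFrame(
        [(0, '2019-01-10'), (1, '2019-01-11'), (0, '2019-01-15'),
         (1, '2019-01-16'), (1, '2019-01-19'), (0, '2019-01-23'),
         (0, '2019-01-25'), (1, '2019-01-29'), (1, '2019-01-30')],
        ['cond', 'dt']
    )

  • Since the question mentions partitioning the full dataset by a UUID, the same expression just takes the grouping column in partitionBy, so each group is processed independently and in parallel. A minimal sketch, assuming the grouping column is called uuid (it is not part of the sample data); rowsBetween(1, wd.unboundedFollowing) expresses the same "all following rows" frame as sys.maxsize.

    data_sdf. \
        withColumn('dt_w_cond1_lead',
                   # same logic as above; 'uuid' is an assumed grouping column, not part of the sample data
                   func.first(func.when(func.col('cond') == 1, func.col('dt')), ignorenulls=True).
                   over(wd.partitionBy('uuid').orderBy('dt').rowsBetween(1, wd.unboundedFollowing))
                   ). \
        show()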