python pyspark aggregate spark-window-function

PySpark - calculate median value with a sliding time window


I have the following data frame in pyspark:

date user_country account_type num_listens
2022-08-01 UK premium 32
2022-08-01 DE free 64
2022-08-01 FR free 93
2022-08-01 UK free 51
2022-08-02 UK premium 26
2022-08-02 FR free 34
2022-08-02 DE free 29
2022-08-02 DE premium 41
2022-08-02 DE free 12
2022-08-02 FR premium 31
2022-08-03 FR free 55
2022-08-03 UK premium 38
2022-08-03 UK premium 51
2022-08-03 FR free 81
2022-08-04 DE free 6
2022-08-04 UK premium 97
2022-08-04 FR free 33
2022-08-04 UK premium 41
2022-08-04 FR premium 67
2022-08-04 DE free 86
2022-08-04 DE free 25
2022-08-04 FR free 16
2022-08-04 FR free 48
2022-08-04 UK premium 11
2022-08-04 UK free 24
2022-08-05 DE free 95
2022-08-05 FR free 68
2022-08-05 DE premium 23
2022-08-05 UK free 79
2022-08-05 UK free 41
2022-08-05 DE premium 99
columns = ["date", "user_country","account_type", "num_listens"]
data = [("2022-08-01", "UK", "premium", "32"),
        ("2022-08-01", "DE", "free", "64"),
        ("2022-08-01", "FR", "free", "93"),
        ("2022-08-01", "UK", "free", "51"),
        ("2022-08-02", "UK", "premium", "26"),
        ("2022-08-02", "FR", "free", "34"),
        ("2022-08-02", "DE", "free", "29"),
        ("2022-08-02", "DE", "premium", "41"),
        ("2022-08-02", "DE", "free", "12"),
        ("2022-08-02", "FR", "premium", "31"),
        ("2022-08-03", "FR", "free", "55"),
        ("2022-08-03", "UK", "premium", "38"),
        ("2022-08-03", "UK", "premium", "51"),
        ("2022-08-03", "FR", "free", "81"),
        ("2022-08-04", "DE", "free", "6"),
        ("2022-08-04", "UK", "premium", "97"),
        ("2022-08-04", "FR", "free", "33"),
        ("2022-08-04", "UK", "premium", "41"),
        ("2022-08-04", "FR", "premium", "67"),
        ("2022-08-04", "DE", "free", "86"),
        ("2022-08-04", "DE", "free", "25"),
        ("2022-08-04", "FR", "free", "16"),
        ("2022-08-04", "FR", "free", "48"),
        ("2022-08-04", "UK", "premium", "11"),
        ("2022-08-04", "UK", "free", "24"),
        ("2022-08-05", "DE", "free", "95"),
        ("2022-08-05", "FR", "free", "68"),
        ("2022-08-05", "DE", "premium", "23"),
        ("2022-08-05", "UK", "free", "79"),
        ("2022-08-05", "UK", "free", "41"),
        ("2022-08-05", "DE", "premium", "99")        
       ]

I'm trying to group this data by user_country and account_type and calculate the median of num_listens for each group. On top of this, I would like to use a sliding time window to restrict the data used for each aggregation. For example, when calculating the median value for 2022-08-04, I would only like to use data from the ten days prior.

The resulting table should look as follows:

snapshot_date user_country account_type median
2022-08-06 UK premium 38
2022-08-06 DE free 29
2022-08-06 FR free 52
2022-08-06 UK free 46
2022-08-06 DE premium 41
2022-08-06 FR premium 49
2022-08-05 UK premium 38
2022-08-05 DE free 27
2022-08-05 FR free 48
2022-08-05 UK free 38
2022-08-05 DE premium 41
2022-08-05 FR premium 49
2022-08-04 UK premium 35
2022-08-04 DE free 29
2022-08-04 FR free 68
2022-08-04 UK free 51
2022-08-04 DE premium 41
2022-08-04 FR premium 31
2022-08-03 UK premium 29
2022-08-03 DE free 29
2022-08-03 FR free 64
2022-08-03 UK free 51
2022-08-03 DE premium 41
2022-08-03 FR premium 31
2022-08-02 UK premium 32
2022-08-02 DE free 64
2022-08-02 FR free 93
2022-08-02 UK free 51

The value in the first row would be the median number of listens for all UK users with a premium account, using data from the previous 10 days (I only included a small sample of 5 days, so in this specific case the full desired range of 10 days is not available).

Any help on how this can be achieved in pyspark would be much appreciated. I've been fiddling around with combining a group by with a window function but have been unable to get the desired result.


Solution

  • You can collect the values in an array, and then apply the median logic on that.

    For simplicity, I'll calculate the median over a window of 4 rows (the current row and the 3 preceding ones) using your sample data. This assumes you don't need the dates to be contiguous, i.e. the previous 3 rows may or may not fall on consecutive dates ([2022-01-01, 2022-01-03, 2022-01-04, 2022-01-04] is also acceptable).

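    from pyspark.sql import functions as func
    from pyspark.sql.window import Window as wd

    # num_listens is assumed to be numeric here; cast it first if it was loaded as a string
    data_sdf. \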
        withColumn('num_listens_arr', 
                   func.array_sort(func.collect_list('num_listens').
                                   over(wd.partitionBy('user_country', 'account_type').orderBy('date').rowsBetween(-3, 0))
                                   )
                   ). \
        withColumn('median', 
                   func.when(func.size('num_listens_arr') % 2 == 0, 
                             func.expr('(num_listens_arr[int(size(num_listens_arr) / 2)-1] + num_listens_arr[int(size(num_listens_arr) / 2)]) / 2').cast('double')
                             ).
                   otherwise(func.expr('num_listens_arr[int(size(num_listens_arr) / 2)]').cast('double'))
                   ). \
        show(data_sdf.count())
    
    # +----------+------------+------------+-----------+----------------+------+
    # |      date|user_country|account_type|num_listens| num_listens_arr|median|
    # +----------+------------+------------+-----------+----------------+------+
    # |2022-08-01|          UK|        free|         51|            [51]|  51.0|
    # |2022-08-04|          UK|        free|         24|        [24, 51]|  37.5|
    # |2022-08-05|          UK|        free|         79|    [24, 51, 79]|  51.0|
    # |2022-08-05|          UK|        free|         41|[24, 41, 51, 79]|  46.0|
    # |2022-08-01|          UK|     premium|         32|            [32]|  32.0|
    # |2022-08-02|          UK|     premium|         26|        [26, 32]|  29.0|
    # |2022-08-03|          UK|     premium|         38|    [26, 32, 38]|  32.0|
    # |2022-08-03|          UK|     premium|         51|[26, 32, 38, 51]|  35.0|
    # |2022-08-04|          UK|     premium|         97|[26, 38, 51, 97]|  44.5|
    # |2022-08-04|          UK|     premium|         41|[38, 41, 51, 97]|  46.0|
    # |2022-08-04|          UK|     premium|         11|[11, 41, 51, 97]|  46.0|
    # |2022-08-02|          DE|     premium|         41|            [41]|  41.0|
    # |2022-08-05|          DE|     premium|         23|        [23, 41]|  32.0|
    # |2022-08-05|          DE|     premium|         99|    [23, 41, 99]|  41.0|
    # |2022-08-01|          DE|        free|         64|            [64]|  64.0|
    # |2022-08-02|          DE|        free|         29|        [29, 64]|  46.5|
    # |2022-08-02|          DE|        free|         12|    [12, 29, 64]|  29.0|
    # |2022-08-04|          DE|        free|          6| [6, 12, 29, 64]|  20.5|
    # |2022-08-04|          DE|        free|         86| [6, 12, 29, 86]|  20.5|
    # |2022-08-04|          DE|        free|         25| [6, 12, 25, 86]|  18.5|
    # |2022-08-05|          DE|        free|         95| [6, 25, 86, 95]|  55.5|
    # |2022-08-01|          FR|        free|         93|            [93]|  93.0|
    # |2022-08-02|          FR|        free|         34|        [34, 93]|  63.5|
    # |2022-08-03|          FR|        free|         55|    [34, 55, 93]|  55.0|
    # |2022-08-03|          FR|        free|         81|[34, 55, 81, 93]|  68.0|
    # |2022-08-04|          FR|        free|         33|[33, 34, 55, 81]|  44.5|
    # |2022-08-04|          FR|        free|         16|[16, 33, 55, 81]|  44.0|
    # |2022-08-04|          FR|        free|         48|[16, 33, 48, 81]|  40.5|
    # |2022-08-05|          FR|        free|         68|[16, 33, 48, 68]|  40.5|
    # |2022-08-02|          FR|     premium|         31|            [31]|  31.0|
    # |2022-08-04|          FR|     premium|         67|        [31, 67]|  49.0|
    # +----------+------------+------------+-----------+----------------+------+
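To make the row-based frame easier to reason about, here is a minimal plain-Python sketch of what `rowsBetween(-3, 0)` selects for a single partition. The helper name `rolling_row_median` is made up for illustration and is not part of PySpark; the input is the UK / premium partition in the row order shown in the output above.

```python
from statistics import median

def rolling_row_median(values, preceding=3):
    """Median over the current row and up to `preceding` rows before it."""
    result = []
    for i in range(len(values)):
        # Same rows that rowsBetween(-preceding, 0) would put in the frame
        frame = values[max(0, i - preceding): i + 1]
        result.append(float(median(frame)))
    return result

# UK / premium listens, in the row order of the output above
uk_premium = [32, 26, 38, 51, 97, 41, 11]
print(rolling_row_median(uk_premium))
# [32.0, 29.0, 32.0, 35.0, 44.5, 46.0, 46.0]
```

Note that several rows share the same date, and ordering by date alone does not fix their relative order, so a row-based frame may pick different rows from one run to the next.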
    

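    If you do want the window to cover consecutive calendar dates, you can use rangeBetween() on the date converted to seconds. The frame below spans the current date and the 3 days before it.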

    data_sdf. \
        withColumn('dt_long', func.col('date').cast('timestamp').cast('long')). \
        withColumn('num_listens_arr', 
                   func.array_sort(func.collect_list('num_listens').
                                   over(wd.partitionBy('user_country', 'account_type').orderBy('dt_long').rangeBetween(-3*24*60*60, 0))
                                   )
                   ). \
        withColumn('median', 
                   func.when(func.size('num_listens_arr') % 2 == 0, 
                             func.expr('(num_listens_arr[int(size(num_listens_arr) / 2)-1] + num_listens_arr[int(size(num_listens_arr) / 2)]) / 2').cast('double')
                             ).
                   otherwise(func.expr('num_listens_arr[int(size(num_listens_arr) / 2)]').cast('double'))
                   ). \
        show(data_sdf.count(), truncate=False)
    
    # +----------+------------+------------+-----------+----------+----------------------------+------+
    # |date      |user_country|account_type|num_listens|dt_long   |num_listens_arr             |median|
    # +----------+------------+------------+-----------+----------+----------------------------+------+
    # |2022-08-01|UK          |free        |51         |1659312000|[51]                        |51.0  |
    # |2022-08-04|UK          |free        |24         |1659571200|[24, 51]                    |37.5  |
    # |2022-08-05|UK          |free        |79         |1659657600|[24, 41, 79]                |41.0  |
    # |2022-08-05|UK          |free        |41         |1659657600|[24, 41, 79]                |41.0  |
    # |2022-08-01|UK          |premium     |32         |1659312000|[32]                        |32.0  |
    # |2022-08-02|UK          |premium     |26         |1659398400|[26, 32]                    |29.0  |
    # |2022-08-03|UK          |premium     |38         |1659484800|[26, 32, 38, 51]            |35.0  |
    # |2022-08-03|UK          |premium     |51         |1659484800|[26, 32, 38, 51]            |35.0  |
    # |2022-08-04|UK          |premium     |97         |1659571200|[11, 26, 32, 38, 41, 51, 97]|38.0  |
    # |2022-08-04|UK          |premium     |41         |1659571200|[11, 26, 32, 38, 41, 51, 97]|38.0  |
    # |2022-08-04|UK          |premium     |11         |1659571200|[11, 26, 32, 38, 41, 51, 97]|38.0  |
    # |2022-08-02|DE          |premium     |41         |1659398400|[41]                        |41.0  |
    # |2022-08-05|DE          |premium     |23         |1659657600|[23, 41, 99]                |41.0  |
    # |2022-08-05|DE          |premium     |99         |1659657600|[23, 41, 99]                |41.0  |
    # |2022-08-01|DE          |free        |64         |1659312000|[64]                        |64.0  |
    # |2022-08-02|DE          |free        |29         |1659398400|[12, 29, 64]                |29.0  |
    # |2022-08-02|DE          |free        |12         |1659398400|[12, 29, 64]                |29.0  |
    # |2022-08-04|DE          |free        |6          |1659571200|[6, 12, 25, 29, 64, 86]     |27.0  |
    # |2022-08-04|DE          |free        |86         |1659571200|[6, 12, 25, 29, 64, 86]     |27.0  |
    # |2022-08-04|DE          |free        |25         |1659571200|[6, 12, 25, 29, 64, 86]     |27.0  |
    # |2022-08-05|DE          |free        |95         |1659657600|[6, 12, 25, 29, 86, 95]     |27.0  |
    # |2022-08-01|FR          |free        |93         |1659312000|[93]                        |93.0  |
    # |2022-08-02|FR          |free        |34         |1659398400|[34, 93]                    |63.5  |
    # |2022-08-03|FR          |free        |55         |1659484800|[34, 55, 81, 93]            |68.0  |
    # |2022-08-03|FR          |free        |81         |1659484800|[34, 55, 81, 93]            |68.0  |
    # |2022-08-04|FR          |free        |33         |1659571200|[16, 33, 34, 48, 55, 81, 93]|48.0  |
    # |2022-08-04|FR          |free        |16         |1659571200|[16, 33, 34, 48, 55, 81, 93]|48.0  |
    # |2022-08-04|FR          |free        |48         |1659571200|[16, 33, 34, 48, 55, 81, 93]|48.0  |
    # |2022-08-05|FR          |free        |68         |1659657600|[16, 33, 34, 48, 55, 68, 81]|48.0  |
    # |2022-08-02|FR          |premium     |31         |1659398400|[31]                        |31.0  |
    # |2022-08-04|FR          |premium     |67         |1659571200|[31, 67]                    |49.0  |
    # +----------+------------+------------+-----------+----------+----------------------------+------+
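As a sanity check for the range-based result, the same logic can be sketched in plain Python on a small slice of the data. This is only a reference implementation for small samples, not a replacement for the Spark code; `trailing_median` and `days_back` are hypothetical names, and the window is assumed to include the current date.

```python
from datetime import date, timedelta
from statistics import median

def trailing_median(rows, days_back=3):
    """rows: iterable of (iso_date, country, account_type, num_listens).

    Returns {(iso_date, country, account_type): median}, where each median
    uses the rows of the same group dated within [date - days_back, date].
    """
    parsed = [(date.fromisoformat(d), c, a, int(n)) for d, c, a, n in rows]
    result = {}
    for d, c, a, _ in parsed:
        start = d - timedelta(days=days_back)
        window = [n for d2, c2, a2, n in parsed
                  if (c2, a2) == (c, a) and start <= d2 <= d]
        result[(d.isoformat(), c, a)] = float(median(window))
    return result

# UK / premium rows from the question's sample data
sample = [
    ("2022-08-01", "UK", "premium", "32"),
    ("2022-08-02", "UK", "premium", "26"),
    ("2022-08-03", "UK", "premium", "38"),
    ("2022-08-03", "UK", "premium", "51"),
    ("2022-08-04", "UK", "premium", "97"),
    ("2022-08-04", "UK", "premium", "41"),
    ("2022-08-04", "UK", "premium", "11"),
]
medians = trailing_median(sample, days_back=3)
print(medians[("2022-08-04", "UK", "premium")])  # 38.0
```

For the ten-day window from the question, `days_back=10` would correspond to `rangeBetween(-10*24*60*60, 0)` in the Spark code, again assuming the current date should be part of the window.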
    

    As for the median calculation: if the array of values has an even number of elements, the median is the average of the two middle elements; otherwise it is the middle element.
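The indexing used in the `when` / `otherwise` expression can be sketched in plain Python as follows. `median_of_sorted` is a hypothetical helper written only to illustrate the index arithmetic, and it assumes the list is already sorted and non-empty.

```python
def median_of_sorted(arr):
    """Median of an already sorted, non-empty list."""
    mid = len(arr) // 2  # same index as int(size(arr) / 2) in the SQL expression
    if len(arr) % 2 == 0:
        # Even length: average the two middle elements
        return (arr[mid - 1] + arr[mid]) / 2
    # Odd length: take the single middle element
    return float(arr[mid])

print(median_of_sorted([24, 51]))          # 37.5
print(median_of_sorted([26, 32, 38]))      # 32.0
print(median_of_sorted([6, 12, 29, 64]))   # 20.5
```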