I have the following data frame in pyspark:
date | user_country | account_type | num_listens |
---|---|---|---|
2022-08-01 | UK | premium | 32 |
2022-08-01 | DE | free | 64 |
2022-08-01 | FR | free | 93 |
2022-08-01 | UK | free | 51 |
2022-08-02 | UK | premium | 26 |
2022-08-02 | FR | free | 34 |
2022-08-02 | DE | free | 29 |
2022-08-02 | DE | premium | 41 |
2022-08-02 | DE | free | 12 |
2022-08-02 | FR | premium | 31 |
2022-08-03 | FR | free | 55 |
2022-08-03 | UK | premium | 38 |
2022-08-03 | UK | premium | 51 |
2022-08-03 | FR | free | 81 |
2022-08-04 | DE | free | 6 |
2022-08-04 | UK | premium | 97 |
2022-08-04 | FR | free | 33 |
2022-08-04 | UK | premium | 41 |
2022-08-04 | FR | premium | 67 |
2022-08-04 | DE | free | 86 |
2022-08-04 | DE | free | 25 |
2022-08-04 | FR | free | 16 |
2022-08-04 | FR | free | 48 |
2022-08-04 | UK | premium | 11 |
2022-08-04 | UK | free | 24 |
2022-08-05 | DE | free | 95 |
2022-08-05 | FR | free | 68 |
2022-08-05 | DE | premium | 23 |
2022-08-05 | UK | free | 79 |
2022-08-05 | UK | free | 41 |
2022-08-05 | DE | premium | 99 |
columns = ["date", "user_country","account_type", "num_listens"]
data = [("2022-08-01", "UK", "premium", "32"),
("2022-08-01", "DE", "free", "64"),
("2022-08-01", "FR", "free", "93"),
("2022-08-01", "UK", "free", "51"),
("2022-08-02", "UK", "premium", "26"),
("2022-08-02", "FR", "free", "34"),
("2022-08-02", "DE", "free", "29"),
("2022-08-02", "DE", "premium", "41"),
("2022-08-02", "DE", "free", "12"),
("2022-08-02", "FR", "premium", "31"),
("2022-08-03", "FR", "free", "55"),
("2022-08-03", "UK", "premium", "38"),
("2022-08-03", "UK", "premium", "51"),
("2022-08-03", "FR", "free", "81"),
("2022-08-04", "DE", "free", "6"),
("2022-08-04", "UK", "premium", "97"),
("2022-08-04", "FR", "free", "33"),
("2022-08-04", "UK", "premium", "41"),
("2022-08-04", "FR", "premium", "67"),
("2022-08-04", "DE", "free", "86"),
("2022-08-04", "DE", "free", "25"),
("2022-08-04", "FR", "free", "16"),
("2022-08-04", "FR", "free", "48"),
("2022-08-04", "UK", "premium", "11"),
("2022-08-04", "UK", "free", "24"),
("2022-08-05", "DE", "free", "95"),
("2022-08-05", "FR", "free", "68"),
("2022-08-05", "DE", "premium", "23"),
("2022-08-05", "UK", "free", "79"),
("2022-08-05", "UK", "free", "41"),
("2022-08-05", "DE", "premium", "99")
]

data_sdf = spark.createDataFrame(data, columns)
I'm trying to group this data by user_country and account_type, calculating the median num_listens for each group. On top of this, I would like to use a sliding time window to restrict the data used for each aggregation. For example, when calculating the median on 2022-08-04, I would only like to use data from the ten dates prior.
The resulting table should look as follows:
snapshot_date | user_country | account_type | median |
---|---|---|---|
2022-08-06 | UK | premium | 38 |
2022-08-06 | DE | free | 29 |
2022-08-06 | FR | free | 52 |
2022-08-06 | UK | free | 46 |
2022-08-06 | DE | premium | 41 |
2022-08-06 | FR | premium | 49 |
2022-08-05 | UK | premium | 38 |
2022-08-05 | DE | free | 27 |
2022-08-05 | FR | free | 48 |
2022-08-05 | UK | free | 38 |
2022-08-05 | DE | premium | 41 |
2022-08-05 | FR | premium | 49 |
2022-08-04 | UK | premium | 35 |
2022-08-04 | DE | free | 29 |
2022-08-04 | FR | free | 68 |
2022-08-04 | UK | free | 51 |
2022-08-04 | DE | premium | 41 |
2022-08-04 | FR | premium | 31 |
2022-08-03 | UK | premium | 29 |
2022-08-03 | DE | free | 29 |
2022-08-03 | FR | free | 64 |
2022-08-03 | UK | free | 51 |
2022-08-03 | DE | premium | 41 |
2022-08-03 | FR | premium | 31 |
2022-08-02 | UK | premium | 32 |
2022-08-02 | DE | free | 64 |
2022-08-02 | FR | free | 93 |
2022-08-02 | UK | free | 51 |
The value in the first row would be the median number of listens for all UK users with a premium account, using data from the previous 10 days (I only included a small sample of 5 days, so in this specific case the full desired range of 10 days would not be available).
Any help on how this can be achieved in pyspark would be much appreciated. I've been fiddling around with combining a group by with a window function but have been unable to get the desired result.
You can collect the values into an array with a window function and then apply the median logic to that array.

For simplicity, I'll calculate the median over a window of 4 rows (the current row and the 3 preceding rows within each group) using your sample data. This assumes you don't require continuity in the dates, i.e. the previous 3 dates may or may not be consecutive ([2022-01-01, 2022-01-03, 2022-01-04, 2022-01-04] is also acceptable).
from pyspark.sql import functions as func
from pyspark.sql.window import Window as wd

data_sdf. \
    withColumn('num_listens', func.col('num_listens').cast('int')). \
    withColumn('num_listens_arr',
               # collect the current row plus the 3 preceding rows per group into a sorted array
               func.array_sort(func.collect_list('num_listens').
                               over(wd.partitionBy('user_country', 'account_type').orderBy('date').rowsBetween(-3, 0))
                               )
               ). \
    withColumn('median',
               # even number of elements: average of the two middle elements
               func.when(func.size('num_listens_arr') % 2 == 0,
                         func.expr('(num_listens_arr[int(size(num_listens_arr) / 2)-1] + num_listens_arr[int(size(num_listens_arr) / 2)]) / 2').cast('double')
                         ).
               # odd number of elements: the middle element
               otherwise(func.expr('num_listens_arr[int(size(num_listens_arr) / 2)]').cast('double'))
               ). \
    show(data_sdf.count())
# +----------+------------+------------+-----------+----------------+------+
# | date|user_country|account_type|num_listens| num_listens_arr|median|
# +----------+------------+------------+-----------+----------------+------+
# |2022-08-01| UK| free| 51| [51]| 51.0|
# |2022-08-04| UK| free| 24| [24, 51]| 37.5|
# |2022-08-05| UK| free| 79| [24, 51, 79]| 51.0|
# |2022-08-05| UK| free| 41|[24, 41, 51, 79]| 46.0|
# |2022-08-01| UK| premium| 32| [32]| 32.0|
# |2022-08-02| UK| premium| 26| [26, 32]| 29.0|
# |2022-08-03| UK| premium| 38| [26, 32, 38]| 32.0|
# |2022-08-03| UK| premium| 51|[26, 32, 38, 51]| 35.0|
# |2022-08-04| UK| premium| 97|[26, 38, 51, 97]| 44.5|
# |2022-08-04| UK| premium| 41|[38, 41, 51, 97]| 46.0|
# |2022-08-04| UK| premium| 11|[11, 41, 51, 97]| 46.0|
# |2022-08-02| DE| premium| 41| [41]| 41.0|
# |2022-08-05| DE| premium| 23| [23, 41]| 32.0|
# |2022-08-05| DE| premium| 99| [23, 41, 99]| 41.0|
# |2022-08-01| DE| free| 64| [64]| 64.0|
# |2022-08-02| DE| free| 29| [29, 64]| 46.5|
# |2022-08-02| DE| free| 12| [12, 29, 64]| 29.0|
# |2022-08-04| DE| free| 6| [6, 12, 29, 64]| 20.5|
# |2022-08-04| DE| free| 86| [6, 12, 29, 86]| 20.5|
# |2022-08-04| DE| free| 25| [6, 12, 25, 86]| 18.5|
# |2022-08-05| DE| free| 95| [6, 25, 86, 95]| 55.5|
# |2022-08-01| FR| free| 93| [93]| 93.0|
# |2022-08-02| FR| free| 34| [34, 93]| 63.5|
# |2022-08-03| FR| free| 55| [34, 55, 93]| 55.0|
# |2022-08-03| FR| free| 81|[34, 55, 81, 93]| 68.0|
# |2022-08-04| FR| free| 33|[33, 34, 55, 81]| 44.5|
# |2022-08-04| FR| free| 16|[16, 33, 55, 81]| 44.0|
# |2022-08-04| FR| free| 48|[16, 33, 48, 81]| 40.5|
# |2022-08-05| FR| free| 68|[16, 33, 48, 68]| 40.5|
# |2022-08-02| FR| premium| 31| [31]| 31.0|
# |2022-08-04| FR| premium| 67| [31, 67]| 49.0|
# +----------+------------+------------+-----------+----------------+------+
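To sanity-check the medians in the output above without a Spark session, here is a plain-Python sketch of the same rowsBetween(-3, 0) logic for a single group; the function names are illustrative, not part of the Spark API. The values are the UK/premium rows in date order:

```python
def median_of_sorted(arr):
    """Median of an already-sorted list: the middle element if the size
    is odd, the average of the two middle elements if it is even."""
    n = len(arr)
    if n % 2 == 0:
        return (arr[n // 2 - 1] + arr[n // 2]) / 2
    return float(arr[n // 2])

def sliding_median(values, window=4):
    """For each row, take it plus up to the 3 preceding rows
    (the rowsBetween(-3, 0) frame), sort them, and take the median."""
    out = []
    for i in range(len(values)):
        frame = sorted(values[max(0, i - window + 1): i + 1])
        out.append(median_of_sorted(frame))
    return out

# UK / premium num_listens in date order
print(sliding_median([32, 26, 38, 51, 97, 41, 11]))
# [32.0, 29.0, 32.0, 35.0, 44.5, 46.0, 46.0]
```

These match the `median` column for the UK/premium rows above.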
If you do want the window to follow the actual sequence of dates, you can use a rangeBetween() instead.
data_sdf. \
    withColumn('num_listens', func.col('num_listens').cast('int')). \
    withColumn('dt_long', func.col('date').cast('timestamp').cast('long')). \
    withColumn('num_listens_arr',
               # rangeBetween works on the ordering column's values, so order by
               # epoch seconds and take a range of 3 days (in seconds) back
               func.array_sort(func.collect_list('num_listens').
                               over(wd.partitionBy('user_country', 'account_type').orderBy('dt_long').rangeBetween(-3*24*60*60, 0))
                               )
               ). \
    withColumn('median',
               func.when(func.size('num_listens_arr') % 2 == 0,
                         func.expr('(num_listens_arr[int(size(num_listens_arr) / 2)-1] + num_listens_arr[int(size(num_listens_arr) / 2)]) / 2').cast('double')
                         ).
               otherwise(func.expr('num_listens_arr[int(size(num_listens_arr) / 2)]').cast('double'))
               ). \
    show(data_sdf.count(), truncate=False)
# +----------+------------+------------+-----------+----------+----------------------------+------+
# |date |user_country|account_type|num_listens|dt_long |num_listens_arr |median|
# +----------+------------+------------+-----------+----------+----------------------------+------+
# |2022-08-01|UK |free |51 |1659312000|[51] |51.0 |
# |2022-08-04|UK |free |24 |1659571200|[24, 51] |37.5 |
# |2022-08-05|UK |free |79 |1659657600|[24, 41, 79] |41.0 |
# |2022-08-05|UK |free |41 |1659657600|[24, 41, 79] |41.0 |
# |2022-08-01|UK |premium |32 |1659312000|[32] |32.0 |
# |2022-08-02|UK |premium |26 |1659398400|[26, 32] |29.0 |
# |2022-08-03|UK |premium |38 |1659484800|[26, 32, 38, 51] |35.0 |
# |2022-08-03|UK |premium |51 |1659484800|[26, 32, 38, 51] |35.0 |
# |2022-08-04|UK |premium |97 |1659571200|[11, 26, 32, 38, 41, 51, 97]|38.0 |
# |2022-08-04|UK |premium |41 |1659571200|[11, 26, 32, 38, 41, 51, 97]|38.0 |
# |2022-08-04|UK |premium |11 |1659571200|[11, 26, 32, 38, 41, 51, 97]|38.0 |
# |2022-08-02|DE |premium |41 |1659398400|[41] |41.0 |
# |2022-08-05|DE |premium |23 |1659657600|[23, 41, 99] |41.0 |
# |2022-08-05|DE |premium |99 |1659657600|[23, 41, 99] |41.0 |
# |2022-08-01|DE |free |64 |1659312000|[64] |64.0 |
# |2022-08-02|DE |free |29 |1659398400|[12, 29, 64] |29.0 |
# |2022-08-02|DE |free |12 |1659398400|[12, 29, 64] |29.0 |
# |2022-08-04|DE |free |6 |1659571200|[6, 12, 25, 29, 64, 86] |27.0 |
# |2022-08-04|DE |free |86 |1659571200|[6, 12, 25, 29, 64, 86] |27.0 |
# |2022-08-04|DE |free |25 |1659571200|[6, 12, 25, 29, 64, 86] |27.0 |
# |2022-08-05|DE |free |95 |1659657600|[6, 12, 25, 29, 86, 95] |27.0 |
# |2022-08-01|FR |free |93 |1659312000|[93] |93.0 |
# |2022-08-02|FR |free |34 |1659398400|[34, 93] |63.5 |
# |2022-08-03|FR |free |55 |1659484800|[34, 55, 81, 93] |68.0 |
# |2022-08-03|FR |free |81 |1659484800|[34, 55, 81, 93] |68.0 |
# |2022-08-04|FR |free |33 |1659571200|[16, 33, 34, 48, 55, 81, 93]|48.0 |
# |2022-08-04|FR |free |16 |1659571200|[16, 33, 34, 48, 55, 81, 93]|48.0 |
# |2022-08-04|FR |free |48 |1659571200|[16, 33, 34, 48, 55, 81, 93]|48.0 |
# |2022-08-05|FR |free |68 |1659657600|[16, 33, 34, 48, 55, 68, 81]|48.0 |
# |2022-08-02|FR |premium |31 |1659398400|[31] |31.0 |
# |2022-08-04|FR |premium |67 |1659571200|[31, 67] |49.0 |
# +----------+------------+------------+-----------+----------+----------------------------+------+
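Note that a range frame includes every row whose date falls in the range, including all rows that share the current row's date. A plain-Python sketch (not Spark code; the names `rows` and `medians` are just for illustration) of how that frame selects rows for the UK/premium group:

```python
from datetime import date

# Each tuple is (date, num_listens) for the UK / premium group.
rows = [(date(2022, 8, 1), 32), (date(2022, 8, 2), 26),
        (date(2022, 8, 3), 38), (date(2022, 8, 3), 51),
        (date(2022, 8, 4), 97), (date(2022, 8, 4), 41),
        (date(2022, 8, 4), 11)]

medians = {}
for d, _ in rows:
    # the rangeBetween(-3*24*60*60, 0) frame: all rows within the
    # previous 3 days, inclusive of the current date's peer rows
    frame = sorted(v for rd, v in rows if 0 <= (d - rd).days <= 3)
    n = len(frame)
    medians[d] = frame[n // 2] if n % 2 else (frame[n // 2 - 1] + frame[n // 2]) / 2

print(medians)
```

For 2022-08-04 the frame is all seven values, giving the median 38 seen in the output above.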
Coming to the median calculation: if the array has an even number of elements, the median is the average of the two middle elements, i.e. the `(size/2)-1`th and `(size/2)`th elements of the sorted array, where `size` is the number of elements. For example, if `size` is 6, those are the `arr[2]` and `arr[3]` elements. If the size is odd, the median is simply the middle element, `arr[int(size/2)]`.
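A quick plain-Python check of that indexing on a sorted array (the values are taken from the DE/free and DE/premium frames in the output above):

```python
arr = sorted([6, 12, 25, 29, 64, 86])  # even size: 6 elements
size = len(arr)
# average of the (size/2)-1 th and (size/2) th elements, i.e. arr[2] and arr[3]
median_even = (arr[size // 2 - 1] + arr[size // 2]) / 2
print(median_even)  # 27.0

arr_odd = sorted([23, 41, 99])  # odd size: the middle element is the median
median_odd = arr_odd[len(arr_odd) // 2]
print(median_odd)  # 41
```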