Tags: pyspark, sas, window, retain

PySpark: retain a value across rows?


I have a problem that is naturally solved with a row-by-row approach in SAS, but I'm stuck doing the same in PySpark. I have a dataset of events per person, ordered by time, for example:

import pandas as pd

test_df = pd.DataFrame({'event_list': [["H"], ["H"], ["H", "F"], ["F"], ["F"], ["H"], ["W"], ["W"]],
                        'time_order': [1, 2, 3, 4, 5, 6, 7, 8],
                        'person': [1, 1, 1, 1, 1, 1, 1, 1]})
test_df = spark.createDataFrame(test_df)
test_df.show()

+----------+----------+------+
|event_list|time_order|person|
+----------+----------+------+
|       [H]|         1|     1|
|       [H]|         2|     1|
|    [H, F]|         3|     1|
|       [F]|         4|     1|
|       [F]|         5|     1|
|       [H]|         6|     1|
|       [W]|         7|     1|
|       [W]|         8|     1|
+----------+----------+------+

I want to group these events into episodes, where every event following the initial event of an episode is contained in that episode's initial event list. In my test_df I would therefore expect 3 episodes:

+----------+----------+------+-------+
|event_list|time_order|person|episode|
+----------+----------+------+-------+
|       [H]|         1|     1|      1|
|       [H]|         2|     1|      1|
|    [H, F]|         3|     1|      2|
|       [F]|         4|     1|      2|
|       [F]|         5|     1|      2|
|       [H]|         6|     1|      2|
|       [W]|         7|     1|      3|
|       [W]|         8|     1|      3|
+----------+----------+------+-------+

In SAS I would carry the prior row's event_list forward in a retained variable, and whenever the current event_list is contained in the retained event_list I would keep the retained value rather than overwriting it. E.g. my retained values would be [null, ["H"], ["H"], ["H","F"], ["H","F"], ["H","F"], ["W"]]. Then I can generate the episodes by tracking changes in the retained values.
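
To make the logic concrete, here is a plain-Python sketch of the kind of sequential, row-by-row logic I have in mind (assuming "contained in" means set containment); it tracks the episode counter directly rather than the lagged retained values:

# plain-Python illustration of the row-by-row logic, not PySpark
rows = [["H"], ["H"], ["H", "F"], ["F"], ["F"], ["H"], ["W"], ["W"]]

episodes = []
retained = None   # event list that currently defines the episode
episode = 0
for event_list in rows:
    # start a new episode unless the current events are contained in the retained list
    if retained is None or not set(event_list) <= set(retained):
        retained = event_list
        episode += 1
    episodes.append(episode)

print(episodes)   # [1, 1, 2, 2, 2, 2, 3, 3]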

In PySpark I'm not sure how to retain information sequentially across rows like this. Is it even possible? My attempts using window functions (partitioning by person and ordering by time_order) have failed. How can I solve this problem in PySpark?


Solution

  • If you are using Spark version >= 2.4, use collect_list on the event_list column over a cumulative window, flatten the collected lists, remove duplicates with array_distinct, and finally use size to count how many distinct events have been seen over time. It would be something like this:

    from pyspark.sql.functions import col, collect_list, flatten, array_distinct, size
    from pyspark.sql.window import Window
    
    # cumulative window: all rows for this person up to and including the current row
    w = Window.partitionBy('person').orderBy('time_order').rowsBetween(Window.unboundedPreceding, 0)
    
    # collect every event list seen so far, flatten, deduplicate, and count the distinct events
    test_df = test_df.withColumn('episode', size(array_distinct(flatten(collect_list(col('event_list')).over(w)))))
    test_df.show()
    
    +----------+----------+------+-------+
    |event_list|time_order|person|episode|
    +----------+----------+------+-------+
    |       [H]|         1|     1|      1|
    |       [H]|         2|     1|      1|
    |    [H, F]|         3|     1|      2|
    |       [F]|         4|     1|      2|
    |       [F]|         5|     1|      2|
    |       [H]|         6|     1|      2|
    |       [W]|         7|     1|      3|
    |       [W]|         8|     1|      3|
    +----------+----------+------+-------+