Collect list inside window function with condition, pyspark


I want to collect a list of all the values of id2 for each id1 that has the same or lower level within a group.

To achieve this I use a window function with collect_list. However, I can't figure out how to express the conditional part. How can that be solved?


from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([
    ("A", 0, "M1", "D1"),
    ("A", 1, "D1", "D2"),
    ("A", 2, "D2", "D3"),
    ("A", 3, "D3", "D4"),
    ("B", 0, "M2", "D5"),
    ("B", 1, "D4", "D6"),
    ("B", 2, "D5", "D7")
], ["group_id", "level", "id1", "id2"])



window = Window.partitionBy('group_id').orderBy('level').rowsBetween(
    Window.unboundedPreceding, Window.unboundedFollowing
)

df_with_list = df.withColumn(
    "list_lower_level",
    F.collect_list("id2").over(window)
)

df_with_list.show()

The output is this:

+--------+-----+---+---+----------------+
|group_id|level|id1|id2|list_lower_level|
+--------+-----+---+---+----------------+
|       A|    0| M1| D1|[D1, D2, D3, D4]|
|       A|    1| D1| D2|[D1, D2, D3, D4]|
|       A|    2| D2| D3|[D1, D2, D3, D4]|
|       A|    3| D3| D4|[D1, D2, D3, D4]|
|       B|    0| M2| D5|    [D5, D6, D7]|
|       B|    1| D4| D6|    [D5, D6, D7]|
|       B|    2| D5| D7|    [D5, D6, D7]|
+--------+-----+---+---+----------------+

However, I want to achieve this:

+--------+-----+---+---+----------------+
|group_id|level|id1|id2|list_lower_level|
+--------+-----+---+---+----------------+
|       A|    0| M1| D1|[D1, D2, D3, D4]|
|       A|    1| D1| D2|    [D2, D3, D4]|
|       A|    2| D2| D3|        [D3, D4]|
|       A|    3| D3| D4|            [D4]|
|       B|    0| M2| D5|    [D5, D6, D7]|
|       B|    1| D4| D6|        [D6, D7]|
|       B|    2| D5| D7|            [D7]|
+--------+-----+---+---+----------------+

Solution

  • In that case you don't need to look at the previous rows; use a frame that starts at the current row instead, so each row collects its own id2 plus every deeper level:

    window = Window.partitionBy('group_id').orderBy('level').rowsBetween(
        Window.currentRow, Window.unboundedFollowing
    )