I want to collect, for each id1, a list of all the id2 values at the same or a lower level within its group.
To achieve this I use collect_list over a window function. However, I don't get how to express the conditional part. How can that be solved?
from pyspark.sql import Window, functions as F

df = spark.createDataFrame([
("A", 0, "M1", "D1"),
("A", 1, "D1", "D2"),
("A", 2, "D2", "D3"),
("A", 3, "D3", "D4"),
("B", 0, "M2", "D5"),
("B", 1, "D4", "D6"),
("B", 2, "D5", "D7")
], ["group_id", "level", "id1", "id2"])
window = Window.partitionBy('group_id').orderBy('level').rowsBetween(
Window.unboundedPreceding, Window.unboundedFollowing
)
df_with_list = df.withColumn(
"list_lower_level",
F.collect_list("id2").over(window)
)
df_with_list.show()
The output is this:
+--------+-----+---+---+----------------+
|group_id|level|id1|id2|list_lower_level|
+--------+-----+---+---+----------------+
| A| 0| M1| D1|[D1, D2, D3, D4]|
| A| 1| D1| D2|[D1, D2, D3, D4]|
| A| 2| D2| D3|[D1, D2, D3, D4]|
| A| 3| D3| D4|[D1, D2, D3, D4]|
| B| 0| M2| D5| [D5, D6, D7]|
| B| 1| D4| D6| [D5, D6, D7]|
| B| 2| D5| D7| [D5, D6, D7]|
+--------+-----+---+---+----------------+
However, I want to achieve this:
+--------+-----+---+---+----------------+
|group_id|level|id1|id2|list_lower_level|
+--------+-----+---+---+----------------+
| A| 0| M1| D1|[D1, D2, D3, D4]|
| A| 1| D1| D2|[D2, D3, D4]|
| A| 2| D2| D3|[D3, D4]|
| A| 3| D3| D4|[D4]|
| B| 0| M2| D5| [D5, D6, D7]|
| B| 1| D4| D6| [D6, D7]|
| B| 2| D5| D7| [D7]|
+--------+-----+---+---+----------------+
In that case you don't need to look at the previous rows. Change the frame so it starts at the current row instead of at the beginning of the partition:
window = Window.partitionBy('group_id').orderBy('level').rowsBetween(
    Window.currentRow, Window.unboundedFollowing
)