Tags: python, pyspark

Find tree hierarchy in group and collect in a list - PySpark


In the data below, for each id2 I want to collect a list of the id1 values that sit above it in the hierarchy/levels.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

spark = SparkSession.builder.getOrCreate()

schema = StructType([
    StructField("group_id", StringType(), False),
    StructField("level", IntegerType(), False),
    StructField("id1", IntegerType(), False),
    StructField("id2", IntegerType(), False)
])

# Feature values
levels = [1, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]

id1_values = [0, 200001, 677555, 677555, 677555, 677555, 677555, 605026, 605026, 605026, 605026, 605026, 605026, 662867, 662867, 662867, 662867, 662867]

id2_values = [200001, 677555, 605026, 662867, 676423, 659933, 660206, 675767, 681116, 913248,
    910758, 913773, 698738, 910387, 910758, 910387, 910113, 910657]

rows = list(zip(['A'] * len(levels), levels, id1_values, id2_values))

# Create DataFrame
data = spark.createDataFrame(rows, schema)

This can be done with a window function and collect_list:

window = Window.partitionBy('group_id').orderBy('level').rowsBetween(Window.unboundedPreceding, Window.currentRow)

data.withColumn("list_id1", F.collect_list("id1").over(window)).show(truncate=False)
Output: 
+--------+-----+------+------+-------------------------------------------------------------------------------------------------------------------------------------------+
|group_id|level|id1   |id2   |list_id1                                                                                                                                   |
+--------+-----+------+------+-------------------------------------------------------------------------------------------------------------------------------------------+
|A       |1    |0     |200001|[0]                                                                                                                                        |
|A       |2    |200001|677555|[0, 200001]                                                                                                                                |
|A       |3    |677555|605026|[0, 200001, 677555]                                                                                                                        |
|A       |3    |677555|662867|[0, 200001, 677555, 677555]                                                                                                                |
|A       |3    |677555|676423|[0, 200001, 677555, 677555, 677555]                                                                                                        |
|A       |3    |677555|659933|[0, 200001, 677555, 677555, 677555, 677555]                                                                                                |
|A       |3    |677555|660206|[0, 200001, 677555, 677555, 677555, 677555, 677555]                                                                                        |
|A       |4    |605026|675767|[0, 200001, 677555, 677555, 677555, 677555, 677555, 605026]                                                                                |
|A       |4    |605026|681116|[0, 200001, 677555, 677555, 677555, 677555, 677555, 605026, 605026]                                                                        |
|A       |4    |605026|913248|[0, 200001, 677555, 677555, 677555, 677555, 677555, 605026, 605026, 605026]                                                                |
|A       |4    |605026|910758|[0, 200001, 677555, 677555, 677555, 677555, 677555, 605026, 605026, 605026, 605026]                                                        |
|A       |4    |605026|913773|[0, 200001, 677555, 677555, 677555, 677555, 677555, 605026, 605026, 605026, 605026, 605026]                                                |
|A       |4    |605026|698738|[0, 200001, 677555, 677555, 677555, 677555, 677555, 605026, 605026, 605026, 605026, 605026, 605026]                                        |
|A       |4    |662867|910387|[0, 200001, 677555, 677555, 677555, 677555, 677555, 605026, 605026, 605026, 605026, 605026, 605026, 662867]                                |
|A       |4    |662867|910758|[0, 200001, 677555, 677555, 677555, 677555, 677555, 605026, 605026, 605026, 605026, 605026, 605026, 662867, 662867]                        |
|A       |4    |662867|910387|[0, 200001, 677555, 677555, 677555, 677555, 677555, 605026, 605026, 605026, 605026, 605026, 605026, 662867, 662867, 662867]                |
|A       |4    |662867|910113|[0, 200001, 677555, 677555, 677555, 677555, 677555, 605026, 605026, 605026, 605026, 605026, 605026, 662867, 662867, 662867, 662867]        |
|A       |4    |662867|910657|[0, 200001, 677555, 677555, 677555, 677555, 677555, 605026, 605026, 605026, 605026, 605026, 605026, 662867, 662867, 662867, 662867, 662867]|
+--------+-----+------+------+-------------------------------------------------------------------------------------------------------------------------------------------+

In some cases there are several distinct id1 values at the same level, and I want the collected list to take this into account.

As an example, on level 4 we have two unique id1s: 605026 and 662867. id2 910387 corresponds to id1 662867 on level 4, so I don't want to include 605026 in its list.

The list I want to collect should include only one id1 per level, capturing the tree path up to level 1.

For id2 910657, that list should be [662867, 677555, 200001, 0].

How can this be achieved using PySpark API?


Solution

  • First off, I would suggest renaming the columns of your table, making id1 the parent and id2 the node. Your table essentially captures parent-child relationships between nodes, and your goal is to construct the full path for each node.
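
    A minimal sketch of that renaming, assuming the DataFrame from the question is bound to the name data:

    # Hypothetical rename of the question's columns so the
    # parent-child relationship is explicit in the names.
    df = (data
          .withColumnRenamed("id1", "parent")
          .withColumnRenamed("id2", "node"))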

    This is a computationally expensive process because it requires repeated joins all the way up to the top of the path. In the worst case it takes as many joins as the longest path is deep. (For hierarchies small enough to fit on the driver, an alternative without joins is sketched after the output below.)

    Here is what I tried:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F
    
    spark = SparkSession.builder.appName("HierarchicalPath").getOrCreate()
    
    data = [
        ("A", 1, 0, 200001),
        ("A", 2, 200001, 677555),
        ("A", 3, 677555, 605026),
        ("A", 3, 677555, 662867),
        ("A", 3, 677555, 676423),
        ("A", 3, 677555, 659933),
        ("A", 3, 677555, 660206),
        ("A", 4, 605026, 675767),
        ("A", 4, 605026, 681116),
        ("A", 4, 605026, 913248),
        ("A", 4, 605026, 910758),
        ("A", 4, 605026, 913773),
        ("A", 4, 605026, 698738),
        ("A", 4, 662867, 910387),
        ("A", 4, 662867, 910758),
        ("A", 4, 662867, 910387),
        ("A", 4, 662867, 910113),
        ("A", 4, 662867, 910657)
    ]
    columns = ["group_id", "level", "parent", "node"]
    df = spark.createDataFrame(data, columns)
    
    df = df.withColumn("path", F.array("parent"))
    df.show(truncate=False)
    
    # Iteratively build the full path
    max_level = df.select(F.max("level")).collect()[0][0]
    current_level = max_level
    
    # Keep a copy of the original df to preserve child-parent relationships
    original_df = df
    
    while current_level > 1:
        # Repeatedly join against the original df to fetch the next parent
        # and overwrite the "growing" `df`
        joined_df = df.alias("child").join(
            original_df.alias("parent"),
            F.col("child.parent") == F.col("parent.node"),
            "left" # Left join because some paths are shorter than others
        ).select(
            F.col("child.group_id"),
            F.col("child.level"),
            F.col("parent.parent").alias("parent"),
            F.col("child.node"),
            # Append the latest parent to the path only if it's not null
            F.expr("CASE WHEN parent.parent IS NOT NULL THEN array_union(child.path, array(parent.parent)) ELSE child.path END").alias("path")
        )
        
        df = joined_df
        
        current_level -= 1
    
    df.show(truncate=False)
    
    # Some massaging in the end to produce accurate results: add the node itself
    # to the path and reverse the list
    result_df = df.select(
        "node",
        F.expr("array_union(reverse(path), array(node))").alias("full_path")
    ).orderBy("level")
    
    result_df.show(truncate=False)
    

    This gives me the following output:

    
    +------+-----------------------------------+
    |node  |full_path                          |
    +------+-----------------------------------+
    |200001|[0, 200001]                        |
    |677555|[0, 200001, 677555]                |
    |605026|[0, 200001, 677555, 605026]        |
    |662867|[0, 200001, 677555, 662867]        |
    |676423|[0, 200001, 677555, 676423]        |
    |659933|[0, 200001, 677555, 659933]        |
    |660206|[0, 200001, 677555, 660206]        |
    |675767|[0, 200001, 677555, 605026, 675767]|
    |681116|[0, 200001, 677555, 605026, 681116]|
    |910387|[0, 200001, 677555, 662867, 910387]|
    |910758|[0, 200001, 677555, 662867, 910758]|
    |910387|[0, 200001, 677555, 662867, 910387]|
    |910113|[0, 200001, 677555, 662867, 910113]|
    |910657|[0, 200001, 677555, 662867, 910657]|
    |913248|[0, 200001, 677555, 605026, 913248]|
    |910758|[0, 200001, 677555, 605026, 910758]|
    |913773|[0, 200001, 677555, 605026, 913773]|
    |698738|[0, 200001, 677555, 605026, 698738]|
    +------+-----------------------------------+
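
    Note that the question asked for [662867, 677555, 200001, 0] for id2 910657, i.e. the ancestors only, closest parent first. The intermediate path column already has exactly that shape once the loop finishes, so if that is the format you want, you can skip the final massaging:

    # `path` is built parent-first in the loop, i.e. [closest parent, ..., root],
    # which is exactly the list format requested in the question.
    df.select("node", "path").show(truncate=False)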
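
    If the hierarchy is small enough to collect to the driver, the repeated joins can be avoided entirely. A minimal sketch, assuming original_df from the code above and that each node has a single parent (note that node 910758 actually has two parents in this data, so the join-based approach handles that case better):

    # Collect the node -> parent edges once, then walk each node's
    # ancestors in plain Python (assumes the map fits in driver memory).
    parent_map = {row["node"]: row["parent"]
                  for row in original_df.select("node", "parent").distinct().collect()}

    def ancestors(node):
        # Closest parent first, root last
        path = []
        cur = parent_map.get(node)
        while cur is not None:
            path.append(cur)
            cur = parent_map.get(cur)
        return path

    print(ancestors(910657))  # [662867, 677555, 200001, 0]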