
How to efficiently recurse through a PySpark DataFrame?


I have a DataFrame which looks roughly like this:

Material  Component  BatchSize  RequiredQuantity
A         A1         1300       1.0
A         A2         1300       0.056
A         A3         1300       2.78
A         B          1300       1300.5
B         B1         1000       1007
B         B2         1000       3.5
B         C          1000       9
C         C1         800        806.4

For each material, I need to recurse through its components down to the lowest-level components, adding a row for each one and normalizing the RequiredQuantity of the new rows as RequiredQuantity / BatchSize * parent RequiredQuantity. The resulting DataFrame should look like this:

Material  Component  BatchSize  RequiredQuantity
A         A1         1300       1.0
A         A2         1300       0.056
A         A3         1300       2.78
A         B          1300       1300.5
A         B1         1300       1309.6035
A         B2         1300       4.55175
A         C          1300       11.7045
A         C1         1300       11.798136
B         B1         1000       1007
B         B2         1000       3.5
B         C          1000       9
B         C1         1000       9.072
C         C1         800        806.4
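
For example, the B1 row under material A is 1007 / 1000 * 1300.5 = 1309.6035, and the C1 row under A is 806.4 / 800 * 11.7045 = 11.798136 (using C's already-normalized quantity under A as the parent quantity).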

I tried writing a recursive function, which does work, but is extremely slow, taking roughly 5 minutes per material. This would be fine for a small table, but in our case we have almost 5000 different materials, each of which has roughly 10 different components, so it would take weeks to get through it all. I'm hoping there's a better way to handle this.

Here's the PySpark code I wrote:

from pyspark.sql import functions as F


def recurse_components(df, material):
    if df.isEmpty():
        return df

    # Rows where this material is the parent, plus its batch size.
    filtered_material = df.where(F.col("Material") == material)
    batch_size = filtered_material.select("BatchSize").first()["BatchSize"]

    component_list = (
        filtered_material.select("Component").rdd.flatMap(lambda x: x).collect()
    )

    for component in component_list:
        # Only recurse if the component is itself a material with its own components.
        component_table = df.where(F.col("Material") == component)
        if not component_table.isEmpty():
            required_quantity = (
                filtered_material.where(F.col("Component") == component)
                .select("RequiredQuantity")
                .first()["RequiredQuantity"]
            )
            recursive_call = recurse_components(df, component).withColumns(
                {
                    "Material": F.lit(material),
                    "RequiredQuantity": F.col("RequiredQuantity")
                    * required_quantity
                    / F.col("BatchSize"),
                    "BatchSize": F.lit(batch_size),
                }
            )
            filtered_material = filtered_material.union(recursive_call)

    return filtered_material

material_list = df.select("Material").distinct().rdd.flatMap(lambda x: x).collect()

# Run the recursion for every distinct material and union the results together.
extended_df = spark.createDataFrame([], df.schema)
for material in material_list:
    extended_df = extended_df.union(recurse_components(df, material))

Any help would be greatly appreciated.


Solution

  • As suggested by Emma, GraphFrames worked great. You first have to define the nodes (in my case, every distinct material and component ID) and the edges (the parent-child rows with their BatchSize and RequiredQuantity), then create the graph, search it for paths of increasing length, and union all the results together. Here's the code:

    from pyspark.sql import functions as F, types as T, SparkSession
    from graphframes import GraphFrame
    
    spark = SparkSession.builder.getOrCreate()
    
    df = spark.createDataFrame(
        [
            ("A", "A1", 1300, 1.0),
            ("A", "A2", 1300, 0.056),
            ("A", "A3", 1300, 2.78),
            ("A", "B", 1300, 1300.5),
            ("B", "B1", 1000, 1007.0),
            ("B", "B2", 1000, 3.5),
            ("B", "C", 1000, 9.0),
            ("C", "C1", 800, 806.4),
        ],
        ["Material", "Component", "BatchSize", "RequiredQuantity"],
    )
    
    nodes = (
        df.select(F.col("Material").alias("id"))
        .union(df.select(F.col("Component").alias("id")))
        .distinct()
    )
    
    edges = (
        df.select(
            F.col("Material").alias("src"),
            F.col("Component").alias("dst"),
            "BatchSize",
            "RequiredQuantity"
        )
        .union(
            # Placeholder edges for the components with no destination; the columns
            # must line up positionally with the select above for the union to work.
            df.select(
                "Component",
                F.lit(None),
                F.lit(None),
                F.lit(None),
            )
        )
        .distinct()
    )
    
    graph = GraphFrame(nodes, edges)
    
    
    @F.udf(T.DoubleType())
    def calculate_quantity(*edges) -> float:
        # Start from the top-level edge's quantity, then scale by each child edge's
        # per-unit requirement (RequiredQuantity / BatchSize) down the path.
        result = edges[0]["RequiredQuantity"]
        for e in edges[1:]:
            result *= e["RequiredQuantity"] / e["BatchSize"]
        return result
    
    
    results = spark.createDataFrame([], df.schema)
    
    i = 1
    while True:
        # Motif for a path of length i, e.g. "(v0)-[e0]->(v1);(v1)-[e1]->(v2)" for i = 2.
        query = ";".join(f"(v{j})-[e{j}]->(v{j+1})" for j in range(i))
        tmp = graph.find(query)
        if tmp.isEmpty():
            # No paths of this length exist, so the deepest components have been reached.
            break
        # Material = start of the path, Component = end of the path, BatchSize from the
        # top-level edge, RequiredQuantity normalized along every edge of the path.
        results = results.union(
            tmp.select(
                F.col("v0")["id"],
                F.col(f"v{i}")["id"],
                F.col("e0")["BatchSize"],
                calculate_quantity(*(col for col in tmp.columns if col.startswith("e"))),
            )
        )
        i += 1
    
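    Because union keeps the column names of the left-hand frame (the empty results frame built from df.schema), the output ends up with the original Material, Component, BatchSize and RequiredQuantity columns, so you can sanity-check it against the expected table with something like this (a minimal usage sketch):

    results.orderBy("Material", "Component").show(truncate=False)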

    This only took a little over 7 minutes to run for all ~5000 materials.

    FYI: If you're doing this in a notebook in Fabric like I am, you'll need to add these at the top (or create an environment with these included):

    %%configure -f
    {
        "conf": {
            "spark.jars.packages": "graphframes:graphframes:0.8.4-spark3.5-s_2.12"
        }
    }
    
    %pip install graphframes-py
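
    If you're not in Fabric, the same package coordinate can instead be passed in when building the session; a minimal sketch, assuming Spark 3.5 with Scala 2.12 to match the coordinate above:

    from pyspark.sql import SparkSession

    spark = (
        SparkSession.builder
        .config("spark.jars.packages", "graphframes:graphframes:0.8.4-spark3.5-s_2.12")
        .getOrCreate()
    )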