I have a DataFrame which looks roughly like this:
Material | Component | BatchSize | RequiredQuantity |
---|---|---|---|
A | A1 | 1300 | 1.0 |
A | A2 | 1300 | 0.056 |
A | A3 | 1300 | 2.78 |
A | B | 1300 | 1300.5 |
B | B1 | 1000 | 1007 |
B | B2 | 1000 | 3.5 |
B | C | 1000 | 9 |
C | C1 | 800 | 806.4 |
For each material, I need to loop through its components and recurse down to the lowest-level components, adding a row for each one and normalizing the RequiredQuantity of the new rows as RequiredQuantity / BatchSize * parent RequiredQuantity (a couple of worked rows are spelled out below the expected output). The resulting DataFrame should look like this:
Material | Component | BatchSize | RequiredQuantity |
---|---|---|---|
A | A1 | 1300 | 1.0 |
A | A2 | 1300 | 0.056 |
A | A3 | 1300 | 2.78 |
A | B | 1300 | 1300.5 |
A | B1 | 1300 | 1309.6035 |
A | B2 | 1300 | 4.55175 |
A | C | 1300 | 11.7045 |
A | C1 | 1300 | 11.798136 |
B | B1 | 1000 | 1007 |
B | B2 | 1000 | 3.5 |
B | C | 1000 | 9 |
B | C1 | 1000 | 9.072 |
C | C1 | 800 | 806.4 |
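For example, the A / B1 row is 1007 / 1000 * 1300.5 = 1309.6035, and the A / C1 row cascades through two levels: 1300.5 * (9 / 1000) * (806.4 / 800) = 11.798136.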
I tried writing a recursive function, which does work, but is extremely slow, taking roughly 5 minutes per material. This would be fine for a small table, but in our case we have almost 5000 different materials, each of which has roughly 10 different components, so it would take weeks to get through it all. I'm hoping there's a better way to handle this.
Here's the PySpark code I wrote:
from pyspark.sql import functions as F

# df is the source DataFrame shown above; spark is the active SparkSession.
def recurse_components(df, material):
    if df.isEmpty():
        return df
    filtered_material = df.where(F.col("Material") == material)
    batch_size = filtered_material.select("BatchSize").first()["BatchSize"]
    component_list = (
        filtered_material.select("Component").rdd.flatMap(lambda x: x).collect()
    )
    for component in component_list:
        # Only recurse into components that are themselves materials.
        component_table = df.where(F.col("Material") == component)
        if not component_table.isEmpty():
            required_quantity = (
                filtered_material.where(F.col("Component") == component)
                .select("RequiredQuantity")
                .first()["RequiredQuantity"]
            )
            # withColumns resolves against the child's original columns, so the
            # division uses the child's BatchSize before it is overwritten with
            # the parent's.
            recursive_call = recurse_components(df, component).withColumns(
                {
                    "Material": F.lit(material),
                    "RequiredQuantity": F.col("RequiredQuantity")
                    * required_quantity
                    / F.col("BatchSize"),
                    "BatchSize": F.lit(batch_size),
                }
            )
            filtered_material = filtered_material.union(recursive_call)
    return filtered_material


material_list = df.select("Material").distinct().rdd.flatMap(lambda x: x).collect()
extended_df = spark.createDataFrame([], df.schema)
for material in material_list:
    extended_df = extended_df.union(recurse_components(df, material))
Any help would be greatly appreciated.
As suggested by Emma, GraphFrames worked great. You first define the nodes (in my case, every distinct material and component) and the edges (the material-to-component rows, carrying BatchSize and RequiredQuantity), then build the graph, iterate through the levels of the graph with motif queries, and union all the results together. Here's the code:
from pyspark.sql import functions as F, types as T, SparkSession
from graphframes import GraphFrame

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [
        ("A", "A1", 1300, 1.0),
        ("A", "A2", 1300, 0.056),
        ("A", "A3", 1300, 2.78),
        ("A", "B", 1300, 1300.5),
        ("B", "B1", 1000, 1007.0),
        ("B", "B2", 1000, 3.5),
        ("B", "C", 1000, 9.0),
        ("C", "C1", 800, 806.4),
    ],
    ["Material", "Component", "BatchSize", "RequiredQuantity"],
)

# Vertices: every distinct material and component id.
nodes = (
    df.select(F.col("Material").alias("id"))
    .union(df.select(F.col("Component").alias("id")))
    .distinct()
)

# Edges: one material -> component edge per row, carrying BatchSize and
# RequiredQuantity. Each component also gets a placeholder edge with a null
# destination; these never match a motif, so they don't affect the results.
edges = (
    df.select(
        F.col("Material").alias("src"),
        F.col("Component").alias("dst"),
        "BatchSize",
        "RequiredQuantity",
    )
    .union(
        df.select(
            "Component",
            F.lit(None),
            F.lit(None),
            F.lit(None),
        )
    )
    .distinct()
)

graph = GraphFrame(nodes, edges)


@F.udf(T.DoubleType())
def calculate_quantity(*edges) -> float:
    # Each edge arrives as a Row. Multiply the quantities down the chain,
    # normalizing every level below the first by that level's batch size.
    result = edges[0]["RequiredQuantity"]
    for e in edges[1:]:
        result *= e["RequiredQuantity"] / e["BatchSize"]
    return result


results = spark.createDataFrame([], df.schema)
i = 1
while True:
    # Motif pattern describing a chain of i edges through the graph.
    query = ";".join(f"(v{j})-[e{j}]->(v{j+1})" for j in range(i))
    tmp = graph.find(query)
    if tmp.isEmpty():
        break
    results = results.union(
        tmp.select(
            F.col("v0")["id"],
            F.col(f"v{i}")["id"],
            F.col("e0")["BatchSize"],
            calculate_quantity(*(col for col in tmp.columns if col.startswith("e"))),
        )
    )
    i += 1
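The first pass (i = 1) just matches the direct material-to-component edges, so the original rows come through unchanged; each further pass extends the chain by one hop (for i = 2 the query string becomes (v0)-[e0]->(v1);(v1)-[e1]->(v2)), and the loop stops at the first depth that has no matches.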
This only took a little over 7 minutes to run for all ~5000 materials.
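If you want to sanity-check the small example above before scaling up, results keeps the column names from df.schema, so something like this reproduces the expected table (up to floating-point rounding):

results.orderBy("Material", "Component").show()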
FYI: If you're doing this in a notebook in Fabric like I am, you'll need to add these at the top (or create an environment with these included):
%%configure -f
{
    "conf": {
        "spark.jars.packages": "graphframes:graphframes:0.8.4-spark3.5-s_2.12"
    }
}
%pip install graphframes-py