I have a relatively shallow directed acyclic graph represented in GraphFrames (a large number of nodes, mostly in disjoint subgraphs). I want to propagate the ids of the root nodes (nodes without incoming edges) to all downstream nodes. To achieve this, I chose the Pregel algorithm. The process should converge once the passed messages stop changing, but it keeps going until the maximum number of iterations is reached.
This is a model of the problem:
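# Imports used by the snippets below (an existing SparkSession named `spark` is assumed):
import pyspark.sql.functions as f
from graphframes import GraphFrame
from graphframes.lib import Pregel

data = [
    ('v1', 'v1'),
    ('v3', 'v1'),
    ('v2', 'v1'),
    ('v4', 'v2'),
    ('v4', 'v5'),
    ('v5', 'v5'),
    ('v6', 'v4'),
]
df = spark.createDataFrame(data, ['variantId', 'explained']).persist()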
# Create nodes:
nodes = (
    df.select(
        f.col('variantId').alias('id'),
        f.when(f.col('variantId') == f.col('explained'), f.col('variantId')).alias('origin_root')
    )
    .distinct()
)
# Create edges:
edges = (
    df
    .filter(f.col('variantId') != f.col('explained'))
    .select(
        f.col('variantId').alias('dst'),
        f.col('explained').alias('src'),
        f.lit('explains').alias('edgeType')
    )
    .distinct()
)
# Converting into a graphframe graph:
graph = GraphFrame(nodes, edges)
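The graph looks like this (edges point from src to dst): v1 -> v2, v1 -> v3, v2 -> v4, v5 -> v4, v4 -> v6.
I want to propagate the root ids (here v1 and v5) to every node downstream.
To do this I wrote the following code: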
maxiter = 3
(
    graph.pregel
    .setMaxIter(maxiter)
    # New column for the resolved roots:
    .withVertexColumn(
        "resolved_roots",
        # The value is initialized by the original root value:
        f.when(
            f.col('origin_root').isNotNull(),
            f.array(f.col('origin_root'))
        ).otherwise(f.array()),
        # When a new value arrives at the node, it gets merged with the existing list:
        f.when(
            Pregel.msg().isNotNull(),
            f.array_union(Pregel.msg(), f.col('resolved_roots'))
        ).otherwise(f.col("resolved_roots"))
    )
    # Each vertex sends its current list of roots downstream (src to dst):
    .sendMsgToDst(Pregel.src("resolved_roots"))
    # Messages arriving at a node are collected and flattened into a single array:
    .aggMsgs(f.flatten(f.collect_list(Pregel.msg())))
    .run()
    .orderBy('id')
    .show()
)
It returns:
+---+-----------+--------------+
| id|origin_root|resolved_roots|
+---+-----------+--------------+
| v1| v1| [v1]|
| v2| null| [v1]|
| v3| null| [v1]|
| v4| null| [v1, v5]|
| v5| v5| [v5]|
| v6| null| [v1, v5]|
+---+-----------+--------------+
All the nodes now have their root information and it no longer changes, yet if we increase the maximum number of iterations to 100, the process keeps running until that limit is reached.
The questions:
Any helpful comment is highly appreciated; I'm completely new to graphs.
Open-source GraphFrames does not take the active message count into consideration, so its Pregel loop exits only when the maximum number of iterations is reached.
The loop condition in the code there looks like this: while (iteration <= maxIter)
There is a GraphFrames library pre-installed with Databricks ML runtimes, which I guess is not open source, and it probably follows the same pattern.
If you need a proper exit based on active messages, you have to use the Spark GraphX Scala API as of now.
The Scala implementation has logic to count the active messages, and it exits once no new active messages are generated.
The loop condition there looks like this: while (isActiveMessagesNonEmpty && i < maxIterations)
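To make the difference between the two loop conditions concrete, here is a minimal plain-Python sketch on the example graph from the question. It is a toy model and not the GraphFrames or GraphX implementation: the function propagate_roots, its stop_when_idle flag, and the rule "a vertex is active if its state changed in the last superstep" are assumptions made purely for illustration.

```python
from collections import defaultdict

# Edges of the example graph from the question, as (src, dst) pairs.
EDGES = [("v1", "v3"), ("v1", "v2"), ("v2", "v4"), ("v5", "v4"), ("v4", "v6")]
VERTICES = ["v1", "v2", "v3", "v4", "v5", "v6"]
ROOTS = {"v1", "v5"}


def propagate_roots(max_iter, stop_when_idle):
    """Toy Pregel-style loop (hypothetical helper, not a library API).

    Returns (resolved roots per vertex, number of supersteps executed).
    """
    # Each root starts with itself; every other vertex starts empty.
    state = {v: ({v} if v in ROOTS else set()) for v in VERTICES}
    # Assumption: a vertex is "active" if its state changed in the last superstep.
    active = set(VERTICES)
    supersteps = 0
    while supersteps < max_iter and (active or not stop_when_idle):
        # Send phase: with the idle check only changed vertices send;
        # without it every vertex sends on every superstep.
        senders = active if stop_when_idle else set(VERTICES)
        inbox = defaultdict(set)
        for src, dst in EDGES:
            if src in senders:
                inbox[dst] |= state[src]
        # Receive phase: merge incoming roots and record which vertices changed.
        active = set()
        for dst, roots in inbox.items():
            if not roots <= state[dst]:
                state[dst] |= roots
                active.add(dst)
        supersteps += 1
    return state, supersteps


if __name__ == "__main__":
    for flag in (False, True):
        result, steps = propagate_roots(max_iter=100, stop_when_idle=flag)
        print(f"stop_when_idle={flag}: {steps} supersteps")
        print({v: sorted(r) for v, r in result.items()})
```

In this toy model both variants end with the same roots per vertex as the table in the question, but the variant with the idle check stops after 4 supersteps while the fixed-count loop runs all 100.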
I have a Medium blog post explaining Pregel in Scala, with some examples somewhat similar to the problem in this thread:
https://towardsdatascience.com/spark-graphx-pregel-its-not-so-complex-as-it-sounds-d196da246c73