In Spark, I have a large list (millions of entries) of lists whose items are all associated with each other. For example:
1: ("A", "C", "D")
# Each of the items in this array is associated with every other item in the array, so A and C are associated, A and D are associated, and C and D are associated.
2: ("F", "H", "I", "P")
3: ("H", "I", "D")
4: ("X", "Y", "Z")
I want to perform an operation that combines the associations when they cross lists. In the example above, all the items of the first three lines are associated with each other: line 1 and line 2 should be combined because, according to line 3, D (from line 1) is associated with H and I (from line 2). Therefore, the output should be:
("A", "C", "D", "F", "H", "I", "P")
("X", "Y", "Z")
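To make the expected result concrete, here is a minimal single-machine sketch in plain Python (not Spark) that merges lists sharing at least one item using a union-find (disjoint-set) structure. The helper name `merge_associated` is hypothetical; the sketch only pins down the semantics on the sample data and assumes everything fits in memory on one machine.

```python
def merge_associated(groups):
    """Merge lists that share at least one item (hypothetical helper, union-find based)."""
    parent = {}

    def find(x):
        # Follow parent pointers to the root, halving the path along the way.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[rb] = ra

    for group in groups:
        for item in group:
            parent.setdefault(item, item)
        # Linking every item to the first one puts the whole list in one set.
        for item in group[1:]:
            union(group[0], item)

    # Collect the items under their root and sort each merged list.
    merged = {}
    for item in parent:
        merged.setdefault(find(item), []).append(item)
    return [sorted(items) for items in merged.values()]


groups = [
    ("A", "C", "D"),
    ("F", "H", "I", "P"),
    ("H", "I", "D"),
    ("X", "Y", "Z"),
]
print(merge_associated(groups))
# [['A', 'C', 'D', 'F', 'H', 'I', 'P'], ['X', 'Y', 'Z']]
```

Linking each item to the first item of its list is enough: union-find only needs every list to end up in a single set, not every pairwise link.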
What type of Spark transformations can I use to perform this operation? I looked at various ways of grouping the data, but haven't found an obvious way to combine lists that share common elements.
Thank you!
As a couple of users have already stated, this can be seen as a graph problem: each item is a node, items that appear in the same list are connected by an edge, and you want to find the connected components of that graph.
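The graph view can be sketched in plain Python before moving to Spark: each item is a node, every pair of items from the same list becomes an edge (the same pairs the self-join in the Spark code produces), and a breadth-first search labels the components. `connected_components` is a hypothetical helper for illustration only; it runs on a single machine.

```python
from collections import deque
from itertools import combinations


def connected_components(groups):
    """Build an undirected graph from the lists and return its connected components."""
    adjacency = {}
    for group in groups:
        # Every item is a node, even if its list has a single element.
        for item in group:
            adjacency.setdefault(item, set())
        # Every pair of items in the same list is an edge (stored in both directions).
        for a, b in combinations(group, 2):
            adjacency[a].add(b)
            adjacency[b].add(a)

    seen = set()
    components = []
    for start in adjacency:
        if start in seen:
            continue
        # Breadth-first search from an unvisited node collects one component.
        seen.add(start)
        queue = deque([start])
        component = []
        while queue:
            node = queue.popleft()
            component.append(node)
            for neighbour in adjacency[node]:
                if neighbour not in seen:
                    seen.add(neighbour)
                    queue.append(neighbour)
        components.append(sorted(component))
    return components


groups = [("A", "C", "D"), ("F", "H", "I", "P"), ("H", "I", "D"), ("X", "Y", "Z")]
print(connected_components(groups))
# [['A', 'C', 'D', 'F', 'H', 'I', 'P'], ['X', 'Y', 'Z']]
```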
As you are using Spark, I think this is a nice opportunity to show how to use GraphFrames, the DataFrame-based graph library that brings GraphX-style algorithms to Python. To run this example you will need the pyspark and graphframes Python packages.
```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from graphframes import GraphFrame

spark = (
    SparkSession.builder.appName("test")
    .config("spark.jars.packages", "graphframes:graphframes:0.8.2-spark3.2-s_2.12")
    .getOrCreate()
)

# GraphFrames' connectedComponents() requires a checkpoint directory.
spark.sparkContext.setCheckpointDir("/tmp/checkpoint")

# Let's create a sample DataFrame.
df = spark.createDataFrame(
    [
        (1, ["A", "C", "D"]),
        (2, ["F", "H", "I", "P"]),
        (3, ["H", "I", "D"]),
        (4, ["X", "Y", "Z"]),
    ],
    ["id", "values"],
)

# Explode the lists so that there is one row per (id, node) pair.
df = df.withColumn("node", f.explode("values"))
df.createOrReplaceTempView("temp_table")

# Join the table with itself to build an edge table with source and destination nodes:
# every pair of distinct nodes that share the same list id becomes an edge.
edge_table = spark.sql(
    """
    SELECT DISTINCT a.node AS src, b.node AS dst
    FROM temp_table a
    JOIN temp_table b
      ON a.id = b.id AND a.node != b.node
    """
)

# Define the graph from a vertex table (one row per distinct node id)
# and the edge table, then call connectedComponents() to find the components.
cc_df = GraphFrame(
    df.selectExpr("node AS id").dropDuplicates(), edge_table
).connectedComponents()

# cc_df has two columns: the node id and its connected component.
# To get the desired result, group by the component and collect the ids into a list.
cc_df.groupBy("component").agg(f.collect_list("id")).show(truncate=False)
```
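One design note, offered as an assumption to verify rather than a tested change: the self-join emits k(k-1) directed edges for a list of k items, which can grow large when lists are long. Connecting every item to a single hub item of its list (for example the first one) uses only k-1 edges per list and yields the same connected components. The plain-Python sketch below compares the two edge sets on the sample data; `clique_edges`, `star_edges` and `components` are hypothetical helpers, and `components` treats edges as undirected.

```python
def clique_edges(groups):
    # Every ordered pair of distinct items in a list, like the self-join on id.
    return {(a, b) for g in groups for a in g for b in g if a != b}


def star_edges(groups):
    # Connect every item to the first item of its list: k - 1 edges per list of k items.
    return {(g[0], item) for g in groups for item in g[1:] if item != g[0]}


def components(nodes, edges):
    # Small union-find over the edges, ignoring their direction.
    parent = {n: n for n in nodes}

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for a, b in edges:
        parent[find(b)] = find(a)

    grouped = {}
    for n in nodes:
        grouped.setdefault(find(n), set()).add(n)
    return sorted(sorted(g) for g in grouped.values())


groups = [("A", "C", "D"), ("F", "H", "I", "P"), ("H", "I", "D"), ("X", "Y", "Z")]
nodes = {item for g in groups for item in g}

print(len(clique_edges(groups)), len(star_edges(groups)))  # 28 9
print(components(nodes, clique_edges(groups)) == components(nodes, star_edges(groups)))  # True
```

In the Spark code, a possible (untested) adaptation would be to skip the self-join and select `f.element_at("values", 1).alias("src")` and `f.col("node").alias("dst")` from the exploded `df`, dropping rows where `src == dst`.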
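The output will contain two rows: one component with A, C, D, F, H, I and P, and one with X, Y and Z. The component ids and the order of the elements inside each list may vary between runs.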
You can install the dependencies by using:
```shell
pip install -q pyspark==3.2 graphframes
```