I am trying to group the column values based on related records: parts that appear together in any row should end up in the same group.
partColumns = ["partnumber", "colVal1", "colVal2", "colVal3", "colVal4", "colVal5"]
partrelations = [
    ("part0", "part1", "", "", "", ""),
    ("part1", "", "part2", "", "part4", ""),
    ("part2", "part3", "", "part5", "part6", "part7"),
    ("part10", "part11", "", "", "", ""),
    ("part11", "part13", "part21", "", "", ""),
    ("part13", "part21", "part18", "", "part20", ""),
]
df_part_groups = spark.createDataFrame(data=partrelations, schema=partColumns)
I am trying to get the grouped output using GraphFrames, as below:
import graphframes as G
from pyspark.sql import functions as F

edges = (df_part_groups
         .withColumnRenamed("partnumber", "src")
         .withColumnRenamed("colVal1", "dst")
         .where(F.col("dst") != ""))  # drop edges that point at an empty value

vertices = (edges.select("src")
            .union(edges.select("dst"))
            .distinct()  # GraphFrames expects unique vertex ids
            .withColumnRenamed("src", "id"))

# create a graph and find all connected components
# (connectedComponents needs a checkpoint directory; the path is up to you)
spark.sparkContext.setCheckpointDir("/tmp/checkpoints")
g = G.GraphFrame(vertices, edges)
cc = g.connectedComponents()

display(df_part_groups
        .join(cc.distinct(), df_part_groups.partnumber == cc.id)
        .orderBy("component", "partnumber", "colVal1"))
Above is what I am trying to put together.
Thanks for the help!!
We can solve this with a simple set-intersection check. (I'm not familiar with GraphFrames, unfortunately.)
Step 1: Combine all parts into a single array for each row.
from pyspark.sql import functions as F

df_part_groups1 = df_part_groups.withColumn(
    'parts',
    F.array('partnumber', 'colVal1', 'colVal2', 'colVal3', 'colVal4', 'colVal5')
)
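To sanity-check the new column, you can peek at a few rows:

df_part_groups1.select('partnumber', 'parts').show(3, truncate=False)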
Step 2: Collect all_parts, a list of the cleaned per-row part lists, since the groups have to be determined across rows.
def clean_lists(plists):
    # drop the empty-string placeholders from each row's list
    return [list(filter(None, pl)) for pl in plists]

all_parts = clean_lists(
    df_part_groups1.groupBy(F.lit(1))
    .agg(F.collect_list('parts').alias('parts'))
    .collect()[0]
    .parts
)
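For example, clean_lists turns a row's array into just its non-empty parts:

print(clean_lists([['part0', 'part1', '', '', '', '']]))
# [['part0', 'part1']]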
Step 3: Build the groups from the collected all_parts.
def part_of_existing_group(gps, pl):
    # merge the row into the first existing group it shares a part with
    for key in gps.keys():
        if set(gps[key]) & set(pl):
            gps[key] = list(set(gps[key] + pl))
            return True
    return False

def findGroups(plists):
    groups = {}
    index = 1
    for pl in plists:
        # start a new group only if the row matches no existing group
        if len(groups) == 0 or not part_of_existing_group(groups, pl):
            groups[f'G{index}'] = pl
            index += 1
    return groups
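A quick check of the grouping logic on a tiny, made-up input (element order within a group may vary, since sets are unordered):

print(findGroups([['a', 'b'], ['b', 'c'], ['x', 'y']]))
# e.g. {'G1': ['a', 'b', 'c'], 'G2': ['x', 'y']}

Note that a row is merged into the first group it intersects, so an input where a single row bridges two previously separate groups would need an extra merge pass; the sample data here does not hit that case.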
Step 4: Assign groups based on the groups map that you created.
from pyspark.sql.functions import udf

groups = findGroups(all_parts)

@udf
def get_group_val(part):
    # look the part up in the groups map captured from the driver
    for key in groups.keys():
        if part in groups[key]:
            return key
    return None  # not found; default udf return type is string

df_part_groups2 = (df_part_groups1.withColumn('part', F.explode('parts'))
                   .dropDuplicates(['part'])
                   .where(F.col('part') != '')
                   .select('part', 'parts')
                   .withColumn('Group', get_group_val('part')))
df_part_groups2.show()
+------+--------------------+-----+
| part| parts|Group|
+------+--------------------+-----+
| part0|[part0, part1, , ...| G1|
| part1|[part0, part1, , ...| G1|
|part10|[part10, part11, ...| G2|
|part11|[part10, part11, ...| G2|
|part13|[part11, part13, ...| G2|
|part18|[part13, part21, ...| G2|
| part2|[part1, , part2, ...| G1|
|part20|[part13, part21, ...| G2|
|part21|[part11, part13, ...| G2|
| part3|[part2, part3, , ...| G1|
| part4|[part1, , part2, ...| G1|
| part5|[part2, part3, , ...| G1|
| part6|[part2, part3, , ...| G1|
| part7|[part2, part3, , ...| G1|
+------+--------------------+-----+
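If you'd rather keep the lookup on the Spark side instead of inside a Python UDF, the same mapping can be expressed as a join against a small part-to-group DataFrame built from the groups dict; a sketch using the same names as above:

# flatten the groups dict into (part, Group) rows and join it back on
groups_df = spark.createDataFrame(
    [(p, key) for key, parts in groups.items() for p in parts],
    ['part', 'Group']
)
df_alt = (df_part_groups1.withColumn('part', F.explode('parts'))
          .where(F.col('part') != '')
          .dropDuplicates(['part'])
          .select('part', 'parts')
          .join(groups_df, on='part', how='left'))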