python, apache-spark, pyspark, mapreduce

pyspark RDD count nodes in a DAG


I have an RDD which shows up as

["2\t{'3': 1}",
 "3\t{'2': 2}",
 "4\t{'1': 1, '2': 1}",
 "5\t{'4': 3, '2': 1, '6': 1}",
 "6\t{'2': 1, '5': 2}",
 "7\t{'2': 1, '5': 1}",
 "8\t{'2': 1, '5': 1}",
 "9\t{'2': 1, '5': 1}",
 "10\t{'5': 1}",
 "11\t{'5': 2}"]

I can split it up and count the nodes before the '\t', or I can write a function to count the nodes on the right. This is a weighted DAG. If I count by hand, I see there are 11 nodes, but I am unable to figure out how to bring node 1 (which only appears on the right side) into the set of nodes before I do distinct and count. My code is:

    import ast

    def break_nodes(line):
        data_dict = ast.literal_eval(line)

        # Iterate through the dictionary items and yield each key as a node ID
        for key, value in data_dict.items():
            print(f'key {key} val {value}')
            yield int(key)

    nodeIDs = dataRDD.map(lambda line: line.split('\t')) \
                     .flatMap(lambda x: break_nodes(x[1])) \
                     .distinct()

This only counts the nodes to the right of the '\t'. My code for the left side is very simple:

    # map (not flatMap) -- flatMap over a string would split '10' into '1' and '0'
    nodeIDs = dataRDD.map(lambda line: line.split('\t')[0])
    totalCount = nodeIDs.distinct().count()

What modification can I make to the code so that all the nodes get counted? My brain is fried from trying so many ways.

Appreciate the help.


Solution

  • Use flatMap to find all the nodes in the RDD, then use distinct to get the unique nodes:

    import ast

    def find_all(r):
        # Split the line into the left-hand node and its adjacency dict,
        # then return that node together with every right-hand key
        x, y = r.split('\t')
        return [x, *ast.literal_eval(y).keys()]

    nodes = dataRDD.flatMap(find_all).distinct()
    

    nodes.collect()
    # ['4', '5', '10', '2', '1', '9', '3', '6', '7', '8', '11']
    
    nodes.count()
    # 11
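
If you would rather keep the two sides as separate RDDs, as in your own code, you can also union them before deduplicating. Below is a minimal end-to-end sketch of that approach, assuming sc is an existing SparkContext; the sample data is the list from the question.

    import ast

    # Sample data from the question; sc is assumed to be an existing SparkContext
    dataRDD = sc.parallelize([
        "2\t{'3': 1}",
        "3\t{'2': 2}",
        "4\t{'1': 1, '2': 1}",
        "5\t{'4': 3, '2': 1, '6': 1}",
        "6\t{'2': 1, '5': 2}",
        "7\t{'2': 1, '5': 1}",
        "8\t{'2': 1, '5': 1}",
        "9\t{'2': 1, '5': 1}",
        "10\t{'5': 1}",
        "11\t{'5': 2}",
    ])

    # Left-hand node of every line
    left = dataRDD.map(lambda line: line.split('\t')[0])

    # Right-hand nodes: the keys of the adjacency dict after the tab
    right = dataRDD.flatMap(lambda line: ast.literal_eval(line.split('\t')[1]).keys())

    # union simply concatenates the two RDDs, so distinct is still needed
    # to deduplicate nodes that appear on both sides
    totalCount = left.union(right).distinct().count()
    # 11

Note that the node IDs stay strings here; that is fine for distinct and count, but cast them with int(...) if you need numeric IDs downstream.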