graph-theory dijkstra comparable python-class weighted-graph

Dijkstra's weighted shortest path in Python


I'm trying to solve a question from PepCoding (Graph Foundation 1, Shortest Path In Weights) by porting the Java solution to Python.

Question:

1. You are given a graph and a source vertex. The vertices represent cities and the edges represent distance in kms.
2. You are required to find the shortest path to each city (in terms of kms) from the source city along with the total distance on path from source to destinations.

Note -> For output, check the sample output and question video.

Sample input (number of vertices, number of edges, one "v1 v2 wt" line per edge, then the source vertex):

7
9
0 1 10
1 2 10
2 3 10
0 3 40
3 4 2
4 5 3
5 6 3
4 6 8
2 5 5
0

Expected output:

0 via 0 @ 0
1 via 01 @ 10
2 via 012 @ 20
5 via 0125 @ 25
4 via 01254 @ 28
6 via 01256 @ 28
3 via 012543 @ 30

I used a list instead of a PriorityQueue and tried to order the Pair objects by defining comparison operators such as __lt__ and __ge__ on the class. All results are correct except for vertex 3: the distance (30) is right, but I get the path 0123 instead of 012543.

This is my output:

0  via  0 @ 0
1  via  01 @ 10
2  via  012 @ 20
5  via  0125 @ 25
4  via  01254 @ 28
6  via  01256 @ 28
3  via  0123 @ 30  <-- differs; expected: 3 via 012543 @ 30
Loop:  [(28, 6, '01256'), (30, 3, '0123'), (40, 3, '03'), (30, 3, '012543')]
Loop:  [(28, 6, '01256'), (30, 3, '0123'), (40, 3, '03'), (30, 3, '012543'), (36, 6, '012546')]
POP:  (28, 6, '01256')
6  via  01256 @ 28
POP:  (30, 3, '0123')  <-- this is popped instead of POP:  (30, 3, '012543')
3  via  0123 @ 30
POP:  (30, 3, '012543')
POP:  (36, 6, '012546')
POP:  (40, 3, '03')
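Both 0-1-2-3 and 0-1-2-5-4-3 weigh 30, so this is a tie rather than a wrong distance. Two details of list.sort() explain the trace above: it compares items with < only (so __ge__ is never called), and it is stable, so among pairs with equal wsf the one appended first ('0123') stays in front. A minimal sketch with a hypothetical TracedPair class that records which comparison methods get called:

```python
class TracedPair:
    """Like the Pair class in this post, but logs which comparison methods sort() calls."""

    calls = []  # class-level log of comparison method names

    def __init__(self, wsf, v, psf):
        self.wsf = wsf
        self.v = v
        self.psf = psf

    def __lt__(self, o):
        TracedPair.calls.append("__lt__")
        return self.wsf < o.wsf

    def __ge__(self, o):
        # Same body as the question's __ge__; sort() never calls it.
        TracedPair.calls.append("__ge__")
        return len(self.psf) < len(o.psf)


# Same insertion order as the first "Loop:" line of the trace.
pq = [TracedPair(28, 6, "01256"), TracedPair(30, 3, "0123"),
      TracedPair(40, 3, "03"), TracedPair(30, 3, "012543")]
pq.sort()  # stable: equal-wsf items keep their relative order

print(set(TracedPair.calls))  # {'__lt__'}
print([p.psf for p in pq])    # ['01256', '0123', '012543', '03']
```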

The Java solution orders Pair objects with the compareTo method below, and this seems to be what makes the difference:

static class Pair implements Comparable<Pair> {
      int v;
      String psf;
      int wsf;

      Pair(int v, String psf, int wsf){
         this.v = v;
         this.psf = psf;
         this.wsf = wsf;
      }

      public int compareTo(Pair o){
         return this.wsf - o.wsf;
      }
   }
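Because this.wsf - o.wsf returns 0 whenever two weights are equal, the Java comparator treats (30, 3, '0123') and (30, 3, '012543') as equal; which one PriorityQueue hands out first is decided by its internal heap layout, not by compareTo. For reference, a compareTo-style function can be used in Python through functools.cmp_to_key. The sketch below uses (wsf, v, psf) tuples and a hypothetical compare_to helper:

```python
from functools import cmp_to_key


def compare_to(a, b):
    """Mirror of Pair.compareTo: negative, zero or positive, comparing wsf only."""
    return a[0] - b[0]


# (wsf, v, psf) tuples, in the order they were added
pairs = [(30, 3, "0123"), (40, 3, "03"), (30, 3, "012543")]

print(compare_to(pairs[0], pairs[2]))  # 0: a tie as far as compareTo is concerned
ordered = sorted(pairs, key=cmp_to_key(compare_to))
print(ordered)  # [(30, 3, '0123'), (30, 3, '012543'), (40, 3, '03')]
```

Since sorted() is stable, the tie keeps insertion order here; Java's PriorityQueue makes no such promise for equal elements.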

Here is the code in Python:


class Edge:
    def __init__(self, src, nbr, wt):
        self.src = src
        self.nbr = nbr
        self.wt = wt

    def __repr__(self):
        return repr(self.__dict__)


class Pair:
    def __init__(self, wsf, v, psf):
        self.wsf = wsf
        self.v = v
        self.psf = psf

    def __repr__(self):
        return repr((self.wsf, self.v, self.psf))

    def __lt__(self, o):
        return self.wsf < o.wsf

    def __ge__(self, o):
        return len(self.psf) < len(o.psf)


def main():
    vtces = int(input())
    edges = int(input())
    graph = {}
    for i in range(vtces):
        graph[i] = []

    for i in range(edges):
        lines = input().split(" ")
        v1 = int(lines[0])
        v2 = int(lines[1])
        wt = int(lines[2])
        e1 = Edge(v1, v2, wt)
        e2 = Edge(v2, v1, wt)
        graph[e1.src].append(e1)
        graph[e2.src].append(e2)

    src = int(input())
    # print(type(graph))
    # for i in graph:
    #     for j in graph[i]:
    #         print(i, j)

    # Write your code here
    pq = []

    pq.append(Pair(0, src, str(src) + ""))
    # print("\nStart: ",pq)
    visited = [False] * vtces

    while len(pq) > 0:
        pq.sort()
        rem = pq.pop(0)
        # print("POP: ", rem)

        if visited[rem.v] == True:
            continue
        visited[rem.v] = True
        print(rem.v, " via ", rem.psf, "@", rem.wsf)

        for e in graph[rem.v]:
            if visited[e.nbr] == False:
                pq.append(Pair(rem.wsf + e.wt, e.nbr, str(rem.psf) + str(e.nbr)))
                # print("Loop: ",pq)

    # print(pq)


if __name__ == "__main__":
    main()
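Separately from the tie issue, calling pq.sort() and pq.pop(0) on every iteration re-sorts the whole list each time, while heapq keeps each push and pop at O(log n). A sketch, assuming the graph is passed as a list of (v1, v2, wt) tuples instead of being read with input(): heap entries are (wsf, seq, v, psf) tuples, where seq comes from itertools.count so that equal weights are popped in insertion order and the remaining fields are never compared. With this first-in-first-out tie-break, vertex 3 is still reported via 0123, which is equally short but not the path in the expected output.

```python
import heapq
from itertools import count


def shortest_paths(vtces, edge_list, src):
    """Dijkstra with heapq; returns (v, psf, wsf) in the order vertices are finalized."""
    graph = {v: [] for v in range(vtces)}
    for v1, v2, wt in edge_list:  # undirected graph: store both directions
        graph[v1].append((v2, wt))
        graph[v2].append((v1, wt))

    seq = count()  # tie-breaker: earlier pushes win among equal wsf
    pq = [(0, next(seq), src, str(src))]
    visited = [False] * vtces
    result = []
    while pq:
        wsf, _, v, psf = heapq.heappop(pq)
        if visited[v]:
            continue
        visited[v] = True
        result.append((v, psf, wsf))
        for nbr, wt in graph[v]:
            if not visited[nbr]:
                heapq.heappush(pq, (wsf + wt, next(seq), nbr, psf + str(nbr)))
    return result


edges = [(0, 1, 10), (1, 2, 10), (2, 3, 10), (0, 3, 40), (3, 4, 2),
         (4, 5, 3), (5, 6, 3), (4, 6, 8), (2, 5, 5)]
for v, psf, wsf in shortest_paths(7, edges, 0):
    print(v, "via", psf, "@", wsf)
```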

PS: Please 🙏 forgive any typos or unclear explanations; I'm new to this world and trying my best.


Solution

  • After trying several ways to emulate a class comparator (Java's compareTo()), as mentioned here, I finally came across 'Python equivalent of Java's compareTo()' here and made the following change to the Pair class: when two pairs have the same wsf, the one with the longer psf sorts first. This produces 3 via 012543 @ 30 (previously 30, 3, '0123').

    class Pair():
        def __init__(self, wsf, v, psf):
            self.wsf = wsf
            self.v = v
            self.psf = psf
    
        def __repr__(self):
            return repr((self.wsf, self.v, self.psf))
    
        def __lt__(self, other):
            if self.wsf == other.wsf:
                return len(self.psf) > len(other.psf)
            return self.wsf < other.wsf
    
        def __gt__(self, other):
            return other.__lt__(self)
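
A quick check of what the modified __lt__ does with the tie from the trace: equal wsf values now put the longer psf first, so '012543' comes before '0123'. This is a tie-breaking rule of its own (the Java compareTo never looks at the path), so it matches the expected output here but is not guaranteed to match Java's order on other inputs.

```python
class Pair:
    def __init__(self, wsf, v, psf):
        self.wsf = wsf
        self.v = v
        self.psf = psf

    def __lt__(self, other):
        # Equal weights: the longer path so far is treated as "smaller".
        if self.wsf == other.wsf:
            return len(self.psf) > len(other.psf)
        return self.wsf < other.wsf


pq = [Pair(28, 6, "01256"), Pair(30, 3, "0123"), Pair(40, 3, "03"), Pair(30, 3, "012543")]
pq.sort()
print([p.psf for p in pq])  # ['01256', '012543', '0123', '03']
```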
    
    

    I implemented the above with a list; during testing I also tried heapq, but it didn't resolve much, and I'm unsure whether there is a better way.

    while len(pq) > 0:
        pq.sort()
        # print("sLoop: ", pq)
        rem = pq.pop(0)
        # print("POP: ", rem)

        if visited[rem.v]:
            continue
        visited[rem.v] = True
        print(rem.v, "via", rem.psf, "@", rem.wsf)

        for e in graph[rem.v]:
            if not visited[e.nbr]:
                pq.append(Pair(rem.wsf + e.wt, e.nbr, str(rem.psf) + str(e.nbr)))
                # print("Loop: ", pq)

    Any suggestions for improving the code above are welcome.
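On why heapq did not help: heapq and java.util.PriorityQueue are both binary heaps, but they repair the heap differently after a pop. heapq moves the hole all the way down to a leaf and, when the two children compare equal, takes the right one; Java's siftDown stops as soon as the moved element fits and prefers the left child. With a comparison on wsf alone they can therefore hand out equal-weight pairs in different orders; on the sample graph, a heapq version whose __lt__ only compares wsf still pops (30, 3, '0123') first. If the judge's expected output comes from the Java solution (an assumption), one option is to copy Java's sift logic. A minimal sketch, written from my reading of OpenJDK's PriorityQueue; JavaLikePQ and dijkstra_java_order are hypothetical names, and edges are passed as (v1, v2, wt) tuples in input order because adjacency order also affects ties:

```python
class JavaLikePQ:
    """Binary min-heap whose add/remove follow the siftUp/siftDown steps of
    OpenJDK's java.util.PriorityQueue. Items are ordered by key(item) only."""

    def __init__(self, key):
        self.key = key  # plays the role of compareTo: smaller key comes out first
        self.items = []

    def __len__(self):
        return len(self.items)

    def add(self, x):
        # siftUp: move x towards the root while it is strictly smaller than its parent
        self.items.append(x)
        k = len(self.items) - 1
        while k > 0:
            parent = (k - 1) >> 1
            if self.key(x) >= self.key(self.items[parent]):
                break
            self.items[k] = self.items[parent]
            k = parent
        self.items[k] = x

    def remove(self):
        # poll: take the root, move the last item to the root, then siftDown
        items = self.items
        result = items[0]
        x = items.pop()
        n = len(items)
        if n > 0:
            k = 0
            half = n >> 1
            while k < half:
                child = 2 * k + 1
                right = child + 1
                # Java switches to the right child only if it is strictly smaller
                if right < n and self.key(items[child]) > self.key(items[right]):
                    child = right
                if self.key(x) <= self.key(items[child]):
                    break
                items[k] = items[child]
                k = child
            items[k] = x
        return result


def dijkstra_java_order(vtces, edge_list, src):
    """Same loop as the Java solution; edge_list holds (v1, v2, wt) in input order."""
    graph = {v: [] for v in range(vtces)}
    for v1, v2, wt in edge_list:
        graph[v1].append((v2, wt))
        graph[v2].append((v1, wt))

    pq = JavaLikePQ(key=lambda p: p[0])  # p = (wsf, v, psf), compared by wsf only
    pq.add((0, src, str(src)))
    visited = [False] * vtces
    lines = []
    while len(pq) > 0:
        wsf, v, psf = pq.remove()
        if visited[v]:
            continue
        visited[v] = True
        lines.append(f"{v} via {psf} @ {wsf}")
        for nbr, wt in graph[v]:
            if not visited[nbr]:
                pq.add((wsf + wt, nbr, psf + str(nbr)))
    return lines


edges = [(0, 1, 10), (1, 2, 10), (2, 3, 10), (0, 3, 40), (3, 4, 2),
         (4, 5, 3), (5, 6, 3), (4, 6, 8), (2, 5, 5)]
print("\n".join(dijkstra_java_order(7, edges, 0)))
```

On the sample input this prints the expected output, including 3 via 012543 @ 30, without any path-length tie-break.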

    Also, the test case below fails, and I'm not sure what the expected output (or input) of that test case is, since multiple edges are possible. This is my output for it:

    2 via 2 @ 0
    5 via 25 @ 5
    4 via 254 @ 8
    6 via 256 @ 8
    3 via 2543 @ 10
    1 via 21 @ 10
    0 via 210 @ 20
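
Because equal weights allow more than one correct path, one way to check an output without knowing the expected text is to verify that every line is a real path whose weight is the true shortest distance. The failing test's input isn't shown; the paths above happen to fit the sample graph with source 2, so the sketch below (hypothetical helpers shortest_distances and check_output) assumes that graph. It checks each reported line only, not the order of the lines, and it assumes single-digit vertex ids because psf has no separator.

```python
import heapq


def shortest_distances(vtces, edge_list, src):
    """Plain Dijkstra returning distances only, used as a reference."""
    graph = {v: [] for v in range(vtces)}
    for v1, v2, wt in edge_list:
        graph[v1].append((v2, wt))
        graph[v2].append((v1, wt))
    dist = [None] * vtces
    heap = [(0, src)]
    while heap:
        d, v = heapq.heappop(heap)
        if dist[v] is not None:
            continue
        dist[v] = d
        for nbr, wt in graph[v]:
            if dist[nbr] is None:
                heapq.heappush(heap, (d + wt, nbr))
    return dist


def check_output(vtces, edge_list, src, lines):
    """True if every 'v via psf @ w' line is a real path of weight w that is also shortest."""
    weight = {}
    for v1, v2, wt in edge_list:
        # keep the lightest edge if the same pair appears more than once
        for a, b in ((v1, v2), (v2, v1)):
            weight[(a, b)] = min(wt, weight.get((a, b), wt))
    dist = shortest_distances(vtces, edge_list, src)
    for line in lines:
        v, rest = line.split(" via ")
        psf, w = rest.split(" @ ")
        path = [int(c) for c in psf]
        if path[0] != src or path[-1] != int(v):
            return False
        total = 0
        for a, b in zip(path, path[1:]):
            if (a, b) not in weight:
                return False
            total += weight[(a, b)]
        if total != int(w) or total != dist[int(v)]:
            return False
    return True


edges = [(0, 1, 10), (1, 2, 10), (2, 3, 10), (0, 3, 40), (3, 4, 2),
         (4, 5, 3), (5, 6, 3), (4, 6, 8), (2, 5, 5)]
mine = ["2 via 2 @ 0", "5 via 25 @ 5", "4 via 254 @ 8", "6 via 256 @ 8",
        "3 via 2543 @ 10", "1 via 21 @ 10", "0 via 210 @ 20"]
print(check_output(7, edges, 2, mine))  # True
```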
    
    

    Still, I'm satisfied with my result. Thanks, everyone.