
Java TreeSet Weird Behavior


I'm trying to implement Prim's minimum spanning tree algorithm using a balanced BST instead of a priority queue. My implementation is in Java, and since Java already has a library implementation of a red-black tree in the form of TreeSet, I decided to use that instead of writing my own.

A typical implementation of Prim's using a min priority queue takes O(E log E), and my intention behind this implementation was to reduce the time complexity to O(E log V). I believe this can also be done using an indexed priority queue (IPQ), but I went with the BST version because library implementations of self-balancing BSTs are available (both in Java and in C++).

For the examples I have tested this implementation with, it seems to work fine and produces correct results. But when I inspected it more closely to make sure that the time complexity was actually O(E log V), I found that Java's TreeSet was not behaving the way I expected.

Here is my implementation:

    package graph;

    import java.util.Comparator;
    import java.util.TreeSet;

    /**
     * Implementation of Prim's algorithm (eager version) to
     * find Minimum Spanning Tree using a self-balancing BST
     * Time Complexity: O(ElogV)
     * 
     * This implementation uses a self-balancing BST (specifically Red-Black Tree)
     * 
     * We can do eager Prim's implementation using an Indexed Priority Queue (IPQ) as well
     * 
     * Comparison: IPQ vs BST
     * To get next best edge in IPQ, we pop the min element from root, and 
     * then heapify the tree, which overall takes O(lgn). To get next best edge in 
     * BST, it would take O(lgn) as well, and then we’ll have to delete that entry 
     * which would take another O(lgn), but overall it is still O(lgn)
     * 
     * Insertion into both BST and IPQ takes O(lgn) anyway
     * 
     * Update in IPQ takes O(lgn). Update in BST as well can be done in 
     * O(lgn) [search the element in O(lgn) then delete that entry in another 
     * O(lgn) and then insert new entry with updated edge weight (and source node) 
     * in yet another O(lgn). In total, it takes 3*logn but overall still O(lgn)]
     *
     */
    public class PrimsMstUsingSelfBalancingBST extends Graph {

        private int n, m, edgeCount;
        private boolean[] visited;
        private Edge[] mst;
        private double cost;
        private TreeSet<Edge> sortedSet;

        public PrimsMstUsingSelfBalancingBST(int numberOfVertices) {
            super(numberOfVertices);
            n = numberOfVertices;
        }

        public Double mst(int s) {
            m = n - 1; // number of expected edges in mst
            edgeCount = 0;
            mst = new Edge[m];
            visited = new boolean[n];
            sortedSet = new TreeSet<>(getComparator());

            relaxEdgesAtNode(s);

            while (!sortedSet.isEmpty() && edgeCount != m) {
                System.out.println(sortedSet);
                // pollFirst() retrieves and removes smallest element from TreeSet
                Edge edge = sortedSet.pollFirst();
                int nodeIndex = edge.to;

                // skip edges pointing to already visited nodes
                if (visited[nodeIndex]) continue;

                mst[edgeCount] = edge;
                edgeCount++;
                cost += edge.wt;

                relaxEdgesAtNode(nodeIndex);
            }

            return (edgeCount == m) ? cost : null;
        }

        private void relaxEdgesAtNode(int index) {
            visited[index] = true;

            for (Edge edge : adjList.get(index))  {
                int to = edge.to;

                if (visited[to]) continue;

                if (!sortedSet.contains(edge)) {
                    sortedSet.add(edge);
                }
                else {
                    Edge existingEdge = search(edge);
                    if (existingEdge.wt > edge.wt) {
                        sortedSet.remove(existingEdge);
                        sortedSet.add(edge);
                    }
                }
            }
        }

        private Comparator<Edge> getComparator() {
            return new Comparator<Edge>() {
                @Override
                public int compare(Edge e1, Edge e2) {
                    // Java TreeSet is implemented in a way that it uses 
                    // Comparable/Comparator logics for all comparisons.

                    // i.e., it will use this comparator to do comparison 
                    // in contains() method.

                    // It will use this same comparator to do comparison 
                    // during remove() method.

                    // It will also use this same comparator, to perform 
                    // sorting during add() method.

                    // While looking up an edge from contains() or remove(), 
                    // we need to perform check based on destinationNodeIndex,

                    // But to keep elements in sorted order during add() operation
                    // we need to compare elements based on their edge weights

                    // For correct behavior of contains() and remove()
                    if (e1.to == e2.to) return 0;

                    // For correct sorting behavior
                    if (e1.wt > e2.wt) return 1;
                    else if (e1.wt < e2.wt) return -1;

                    // Return -1 or 1 to make sure that different edges with equal 
                    // weights are not ignored by TreeSet.add()
                    // this check can be included in either 'if' or 'else' part 
                    // above. Keeping this separate for readability.
                    return -1;
                }
            };
        }

        // O(log n) search in TreeSet
        private Edge search(Edge e) {
            Edge ceil  = sortedSet.ceiling(e); // smallest element >= e
            Edge floor = sortedSet.floor(e);   // largest element <= e

            return ceil.equals(floor) ? ceil : null; 
        }

        public static void main(String[] args) {
            example1();
        }

        private static void example1() {
            int n = 8;
            PrimsMstUsingSelfBalancingBST graph = 
                    new PrimsMstUsingSelfBalancingBST(n);

            graph.addEdge(0, 1, true, 10);
            graph.addEdge(0, 2, true, 1);
            graph.addEdge(0, 3, true, 4);
            graph.addEdge(2, 1, true, 3);
            graph.addEdge(2, 5, true, 8);
            graph.addEdge(2, 3, true, 2);
            graph.addEdge(3, 5, true, 2);
            graph.addEdge(3, 6, true, 7);
            graph.addEdge(5, 4, true, 1);
            graph.addEdge(5, 7, true, 9);
            graph.addEdge(5, 6, true, 6);
            graph.addEdge(4, 1, true, 0);
            graph.addEdge(4, 7, true, 8);
            graph.addEdge(6, 7, true, 12);

            int s = 0;
            Double cost = graph.mst(s);
            if (cost != null) {
                System.out.println(cost); // 20.0
                for (Edge e : graph.mst)
                    System.out.println(String.format("Used edge (%d, %d) with cost: %.2f", e.from, e.to, e.wt));
                /*
                 * Used edge (0, 2) with cost: 1.00
                 * Used edge (2, 3) with cost: 2.00
                 * Used edge (3, 5) with cost: 2.00
                 * Used edge (5, 4) with cost: 1.00
                 * Used edge (4, 1) with cost: 0.00
                 * Used edge (5, 6) with cost: 6.00
                 * Used edge (4, 7) with cost: 8.00
                 */
            }
            else {
                System.out.println("MST not found!");
            }
        }
    }

Below is the undirected weighted graph I'm testing this with (the same example is used in the code as well):

[Image: example graph]

The issue I'm facing is that TreeSet seems to be adding duplicate entries: its contains() method sometimes returns false even when an entry with the same key (the destination node of the edge, in this case) is already present.

Below is the output of the above program:

[{from=0, to=2, weight=1.00}, {from=0, to=3, weight=4.00}, {from=0, to=1, weight=10.00}]
[{from=2, to=3, weight=2.00}, {from=2, to=1, weight=3.00}, {from=2, to=5, weight=8.00}, {from=0, to=1, weight=10.00}]
[{from=3, to=5, weight=2.00}, {from=2, to=1, weight=3.00}, {from=3, to=6, weight=7.00}, {from=0, to=1, weight=10.00}]
[{from=5, to=4, weight=1.00}, {from=2, to=1, weight=3.00}, {from=5, to=6, weight=6.00}, {from=5, to=7, weight=9.00}, {from=0, to=1, weight=10.00}]
[{from=4, to=1, weight=0.00}, {from=5, to=6, weight=6.00}, {from=4, to=7, weight=8.00}, {from=0, to=1, weight=10.00}]
[{from=5, to=6, weight=6.00}, {from=4, to=7, weight=8.00}, {from=0, to=1, weight=10.00}]
[{from=4, to=7, weight=8.00}, {from=0, to=1, weight=10.00}]
20.0
Used edge (0, 2) with cost: 1.00
Used edge (2, 3) with cost: 2.00
Used edge (3, 5) with cost: 2.00
Used edge (5, 4) with cost: 1.00
Used edge (4, 1) with cost: 0.00
Used edge (5, 6) with cost: 6.00
Used edge (4, 7) with cost: 8.00

As can be clearly seen, even though there is already an entry for destination node 1, {from=0, to=1, weight=10.00} [line 1 of output], it adds another entry for it, {from=2, to=1, weight=3.00} [line 2 of output], instead of updating the existing entry.

When I debugged it by adding a breakpoint inside my custom comparator, I noticed that the comparator was never called for the existing entry, so the comparison with that entry never happened. For example, while processing edge {from=2, to=1, weight=3.00}, Comparator.compare() gets called for entries {from=2, to=3, weight=2.00} and {from=2, to=5, weight=8.00} but not for entry {from=0, to=1, weight=10.00}. It therefore concludes that there is no entry for destination node 1 and adds a new one, which is why I get two entries for destination node 1 [line 2 of output].

I suspect this has something to do with immutability of objects and concurrent modification restrictions in the Java Collections Framework, but I'm unable to understand the root cause of the issue.

Any help is appreciated.


Solution

  • Your Comparator is violating its contract, e.g.

    The implementor must ensure that sgn(compare(x, y)) == -sgn(compare(y, x)) for all x and y. (This implies that compare(x, y) must throw an exception if and only if compare(y, x) throws an exception.)

    This is the compare method, without all the comments:

    public int compare(Edge e1, Edge e2) {
        if (e1.to == e2.to) return 0;
    
        if (e1.wt > e2.wt) return 1;
        else if (e1.wt < e2.wt) return -1;
    
        return -1;
    }
    

    E.g. you have two edges with weight 1:

    a = {from=0, to=2, weight=1.00}
    b = {from=5, to=4, weight=1.00}
    

    Since they have different to values but the same wt value, the method returns -1 regardless of the parameter order, i.e. compare(a, b) = -1 and compare(b, a) = -1.

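    That violates the rule quoted above and will result in unpredictable behavior. It is also consistent with what you saw in the debugger: a TreeSet lookup follows a single root-to-leaf path chosen by the comparison results (here, mostly by weight), so an entry with a matching to value that lies off that path is never compared at all.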
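
To make the violation concrete, and to sketch one possible way around it, the block below puts the comparator from the question next to an ordering that compares by weight and then breaks ties by to and from. The Edge class here is a minimal stand-in (the question does not show it), and the names EdgeOrderingSketch, BROKEN, CONSISTENT, offer and the best array are assumptions made up for this illustration. The idea is to move the lookup by destination node out of the comparator and into an array indexed by to, so that the TreeSet is only responsible for ordering.

```java
import java.util.Comparator;
import java.util.TreeSet;

final class EdgeOrderingSketch {

    // Minimal stand-in for the Edge class, which the question does not show.
    static final class Edge {
        final int from, to;
        final double wt;

        Edge(int from, int to, double wt) {
            this.from = from;
            this.to = to;
            this.wt = wt;
        }
    }

    // The comparator from the question, with the comments removed.
    static final Comparator<Edge> BROKEN = (e1, e2) -> {
        if (e1.to == e2.to) return 0;
        if (e1.wt > e2.wt) return 1;
        else if (e1.wt < e2.wt) return -1;
        return -1;
    };

    // A consistent total order: weight first, then the endpoints as tie-breakers.
    // Two edges compare as equal only if all three fields match.
    static final Comparator<Edge> CONSISTENT =
            Comparator.<Edge>comparingDouble(e -> e.wt)
                      .thenComparingInt(e -> e.to)
                      .thenComparingInt(e -> e.from);

    // Keeps at most one candidate edge per destination node.
    // best[to] answers "is there already an edge to this node?";
    // the TreeSet only keeps the candidates sorted.
    static void offer(TreeSet<Edge> set, Edge[] best, Edge candidate) {
        Edge current = best[candidate.to];
        if (current != null && current.wt <= candidate.wt) {
            return; // the existing edge is at least as good
        }
        if (current != null) {
            set.remove(current); // found reliably because the order is consistent
        }
        set.add(candidate);
        best[candidate.to] = candidate;
    }

    public static void main(String[] args) {
        Edge a = new Edge(0, 2, 1.0);
        Edge b = new Edge(5, 4, 1.0);

        // Same sign in both directions: the contract is broken.
        System.out.println(BROKEN.compare(a, b) + " " + BROKEN.compare(b, a)); // prints -1 -1

        // Opposite signs, as the contract requires.
        System.out.println(CONSISTENT.compare(a, b) + " " + CONSISTENT.compare(b, a)); // prints -1 1

        TreeSet<Edge> set = new TreeSet<>(CONSISTENT);
        Edge[] best = new Edge[8];
        offer(set, best, new Edge(0, 1, 10.0));
        offer(set, best, new Edge(0, 3, 4.0));
        offer(set, best, new Edge(2, 1, 3.0)); // replaces the edge to node 1

        System.out.println(set.size());                               // prints 2
        System.out.println(set.first().from + "->" + set.first().to); // prints 2->1
    }
}
```

With this split, replacing the candidate edge for a node is one remove plus one add on the set, and the set holds at most one edge per unvisited node, which is what the O(E log V) bound depends on. In a full Prim's loop the existing visited check would still be needed when relaxing edges.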