Tags: java, algorithm, performance

Performance problem with a Java implementation of Dinic's algorithm


I am implementing Dinic's algorithm in Java, and I have run into a strange problem.

The vertex names in my graph are strings. When the names are purely numeric, such as 1, 2, 3, ..., 200, the code runs very fast.

However, if I add a prefix to the node names, the code becomes very slow, and the slowdown changes with the prefix string (for example, its length), which I find difficult to understand.

My algorithm implementation code:

package org.apache.misc.alg.dag;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

/**
 * Runs a Dinic-style max-flow over a string-keyed residual network derived from the input graph
 * and returns {@code graph.size() - maxFlow}.
 *
 * <p>Each vertex {@code t} of the input graph is represented by two network vertices, named
 * {@code "A" + t} and {@code "B" + t}, plus the fixed terminals {@code "src"} and {@code "sink"}.
 *
 * <p>NOTE(review): presumably the returned value is the maximum antichain size, going by the
 * name of the implemented interface; {@code MaxAntichainCalculator} is not shown here - confirm.
 *
 * @param <T> vertex type of the input graph; vertices are identified by their {@code toString()}
 */
public class DinicCalculator<T> implements MaxAntichainCalculator<T> {

    // Residual capacities: network.get(u).get(v) is the capacity left on edge u -> v.
    private final Map<String, Map<String, Integer>> network;
    // Every network vertex name; a name's position in this list is its index into `level`.
    private List<String> nodes;
    // BFS distance from the source per vertex (indexed via `nodes`), -1 when not reached.
    private int[] level;

    public DinicCalculator() {
        network = new HashMap<>();
        nodes = new ArrayList<>();
        nodes.add("src");
        nodes.add("sink");
    }

    /**
     * Recomputes {@link #level} with a breadth-first search from {@code source} that only
     * follows edges with positive remaining capacity.
     */
    private void bfs(String source) {
        level = new int[nodes.size()];
        Arrays.fill(level, -1);
        // nodes is an ArrayList, so every indexOf below is a linear scan over all vertex names.
        level[nodes.indexOf(source)] = 0;

        Queue<String> queue = new LinkedList<>();
        queue.offer(source);

        while (!queue.isEmpty()) {
            String u = queue.poll();
            for (Map.Entry<String, Integer> entry : network.get(u).entrySet()) {
                String v = entry.getKey();
                int capacity = entry.getValue();
                if (capacity > 0 && level[nodes.indexOf(v)] == -1) {
                    level[nodes.indexOf(v)] = level[nodes.indexOf(u)] + 1;
                    queue.offer(v);
                }
            }
        }
    }

    /**
     * Depth-first search for a single augmenting path from {@code u} to {@code sink}, following
     * only edges with positive capacity that lead to a vertex with a higher level.
     *
     * <p>On success the capacities along the path are updated and the pushed amount is returned.
     * No record is kept of vertices that were already explored without success, so the same
     * vertex can be searched again each time it is reached from a different predecessor.
     *
     * @param u    vertex to search from
     * @param flow upper bound on the amount that can still be pushed along the current path
     * @param sink name of the target vertex
     * @return the amount pushed, or 0 when no path to {@code sink} was found from {@code u}
     */
    private int dfs(String u, int flow, String sink) {
        if (u.equals(sink)) {
            return flow;
        }

        // Edges are tried in the iteration order of the HashMap holding u's adjacency.
        for (Map.Entry<String, Integer> entry : network.get(u).entrySet()) {
            String v = entry.getKey();
            int capacity = entry.getValue();
            if (capacity > 0 && level[nodes.indexOf(u)] < level[nodes.indexOf(v)]) {
                int sent = dfs(v, Math.min(flow, capacity), sink);
                if (sent > 0) {
                    // Take `sent` off the forward edge and credit it to the reverse edge.
                    network.get(u).put(v, capacity - sent);
                    network.get(v).put(u, network.get(v).getOrDefault(u, 0) + sent);
                    return sent;
                }
            }
        }
        return 0;
    }

    /**
     * Sets the capacity of {@code from -> to} to {@code capacity}, sets the capacity of
     * {@code to -> from} to 0 (overwriting any value already stored for either direction), and
     * registers both endpoints in {@link #nodes} if they are not there yet.
     */
    private void addEdge(String from, String to, int capacity) {
        network.computeIfAbsent(from, k -> new HashMap<>()).put(to, capacity);
        network.computeIfAbsent(to, k -> new HashMap<>()).put(from, 0);
        // contains() on the ArrayList is another linear scan per call.
        if (!nodes.contains(from)) nodes.add(from);
        if (!nodes.contains(to)) nodes.add(to);
    }

    /**
     * Breadth-first traversal of {@code graph} starting at {@code t}.
     *
     * <p>For every vertex taken off the queue the key {@code "A" + vertex} is added to
     * {@code visited}; for every neighbor seen for the first time the key {@code "B" + neighbor}
     * is added. The returned set is the same {@code visited} instance and therefore holds both
     * kinds of keys.
     */
    private Set<String> reach(Map<T, Set<T>> graph, T t, Set<String> visited) {
        Queue<T> queue = new LinkedList<>();
        queue.add(t);

        while (!queue.isEmpty()) {
            T current = queue.poll();
            String currentKey = "A" + current.toString();
            visited.add(currentKey);
            // graph.get(current) is dereferenced directly, so every vertex needs its own entry.
            for (T neighbor : graph.get(current)) {
                String neighborKey = "B" + neighbor.toString();
                if (!visited.contains(neighborKey)) {
                    queue.add(neighbor);
                    visited.add(neighborKey);
                }
            }
        }

        return visited;
    }

    /**
     * Entry point: builds the flow network for {@code graph}, computes the maximum flow from
     * {@code "src"} to {@code "sink"} and returns {@code graph.size() - maxFlow}.
     *
     * @param graph adjacency sets keyed by vertex
     * @return {@code graph.size()} minus the maximum flow
     */
    // entrance
    public int calculator(Map<T, Set<T>> graph) {

        for (T t : graph.keySet()) {
            addEdge("src", "A" + t.toString(), 1);
            addEdge("B" + t, "sink", 1);
            Set<String> visitedSubset = new HashSet<>();
            // One unit-capacity edge from "A" + t to every key collected by reach().
            for (String u : reach(graph, t, visitedSubset)) {
                addEdge("A" + t, u, 1);
            }
        }

        int maxFlow = 0;
        while (true) {
            // New phase: recompute levels and stop once the sink is no longer reachable.
            bfs("src");
            if (level[nodes.indexOf("sink")] == -1) {
                break;
            }

            // Push one augmenting path at a time until the current levels yield no more.
            int flow;
            while ((flow = dfs("src", Integer.MAX_VALUE, "sink")) > 0) {
                maxFlow += flow;
            }
        }

        return graph.size() - maxFlow;
    }
}

My test code:

package org.apache.misc.alg.dag;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Timing harness for {@code DinicCalculator}: runs it on a simple chain graph whose vertex names
 * share a configurable prefix and logs the result together with the elapsed time.
 */
public class DagTests {

    private static final Logger logger = LoggerFactory.getLogger(DagTests.class);

    /** Number of edges in the generated chain; the chain has one more vertex than edges. */
    private static final int CHAIN_EDGES = 200;

    @Test
    public void test() {
        // Test prefixes of different lengths
        // like 0,1,2,3,...,200
        test1("");
        // like A0,A1,A2,A3,...,A200
        test1("A");
        test1("AA");
        test1("AAA");
        test1("x");
        test1("xx");
        // like xx_0,xx_1,xx_2,...,xx_200
        test1("xx_");
    }

    /**
     * Runs the calculator once on a chain whose vertex names start with {@code prefix} and logs
     * the result and the elapsed time in milliseconds.
     *
     * @param prefix string prepended to every vertex number
     */
    public void test1(String prefix) {
        Map<String, Set<String>> graph = genGraph(prefix);
        // nanoTime is monotonic, so the measurement is immune to wall-clock adjustments.
        long start = System.nanoTime();
        int result = new DinicCalculator<String>().calculator(graph);
        long elapsedMillis = (System.nanoTime() - start) / 1_000_000L;
        logger.info("DinicCalculator with prefix: " + prefix + ", result: " + result + ", time: " + elapsedMillis);
    }

    /**
     * Builds the chain {@code prefix0 -> prefix1 -> ... -> prefix200}.
     *
     * @param prefix string prepended to every vertex number
     * @return adjacency sets for all vertices; the last vertex maps to an empty set
     */
    private Map<String, Set<String>> genGraph(String prefix) {
        Map<String, Set<String>> graph = new HashMap<>();
        for (int i = 0; i < CHAIN_EDGES; i++) {
            String from = prefix + i;
            String to = prefix + (i + 1);
            graph.put(from, new HashSet<>(Arrays.asList(to)));
        }

        // The final vertex has no successors but still needs its own entry.
        graph.put(prefix + CHAIN_EDGES, new HashSet<>());
        return graph;
    }
}

My test code output:


18:21:24.609 [main] INFO org.apache.misc.alg.dag.DagTests -- DinicCalculator with prefix: , result: 1, time: 503
18:21:27.137 [main] INFO org.apache.misc.alg.dag.DagTests -- DinicCalculator with prefix: A, result: 1, time: 2526
18:21:48.843 [main] INFO org.apache.misc.alg.dag.DagTests -- DinicCalculator with prefix: AA, result: 1, time: 21706
18:21:55.826 [main] INFO org.apache.misc.alg.dag.DagTests -- DinicCalculator with prefix: AAA, result: 1, time: 6983
18:21:57.199 [main] INFO org.apache.misc.alg.dag.DagTests -- DinicCalculator with prefix: x, result: 1, time: 1373
19:35:07.166 [main] INFO org.apache.misc.alg.dag.DagTests -- DinicCalculator with prefix: xx, result: 1, time: 4389965
19:45:18.590 [main] INFO org.apache.misc.alg.dag.DagTests -- DinicCalculator with prefix: xx_, result: 1, time: 611424

Test info:

I see a similar effect on x64 + Ubuntu 22.04 + JDK 1.8, and also on x64 + CentOS 7.5 + JDK 1.8.

So where exactly is the problem? Could it be caused by the CPU cache?


Solution

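  • Your dfs method is missing a visited check, so in practice you are doing an unrestricted search that re-explores the same dead ends over and over, which is very slow. The vertex names matter because they determine the String hash codes, and therefore the order in which the HashMap adjacency maps are iterated: a different prefix makes the search try edges in a different order, so it may find a path quickly or wander for a very long time first.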

    Also, I changed your nodes field to be a Map<String, Integer>, instead of List<String>, to avoid costly nodes.indexOf calls.

    After that, this is what I see:

    DinicCalculator with prefix: , result: 1, time: 85
    DinicCalculator with prefix: A, result: 1, time: 102
    DinicCalculator with prefix: AA, result: 1, time: 104
    DinicCalculator with prefix: AAA, result: 1, time: 91
    DinicCalculator with prefix: x, result: 1, time: 67
    DinicCalculator with prefix: xx, result: 1, time: 66
    DinicCalculator with prefix: xx_, result: 1, time: 83
    

    Here's the code for DinicCalculator:

    /**
     * Computes {@code graph.size() - maxFlow}, where the maximum flow is taken over a bipartite
     * unit-capacity network built from the transitive closure of the input graph. For a DAG this
     * value is the size of a maximum antichain (minimum chain cover, by Dilworth's theorem).
     *
     * <p>Every vertex {@code t} gets a left copy named {@code "A" + t} and a right copy named
     * {@code "B" + t}. The network has the edges {@code src -> A_t}, {@code B_t -> sink} and
     * {@code A_t -> B_r} for every {@code r} reachable from {@code t} over at least one edge.
     * Vertices are identified by their string form, so distinct vertices must have distinct
     * {@code toString()} values.
     *
     * <p>An instance may be reused: every {@link #calculator(Map)} call rebuilds the network from
     * scratch. Instances hold mutable state without synchronization and are not thread-safe.
     *
     * @param <T> vertex type of the input graph
     */
    public class DinicCalculator<T> {

        private static final String SOURCE = "src";
        private static final String SINK = "sink";

        /** Residual capacities: {@code network.get(u).get(v)} is what is left on {@code u -> v}. */
        private final Map<String, Map<String, Integer>> network;
        /** Dense index of every network vertex, used to address {@link #level}. */
        private final Map<String, Integer> nodes;
        /** BFS distance from the source per vertex index, {@code -1} when not reached. */
        private int[] level;

        public DinicCalculator() {
            network = new HashMap<>();
            nodes = new HashMap<>();
            reset();
        }

        /** Drops all state of a previous run and registers the two terminals. */
        private void reset() {
            network.clear();
            nodes.clear();
            // Registering the terminals up front keeps bfs() safe for an empty input graph.
            network.put(SOURCE, new HashMap<>());
            network.put(SINK, new HashMap<>());
            nodes.put(SOURCE, nodes.size());
            nodes.put(SINK, nodes.size());
        }

        /**
         * Recomputes {@link #level} with a breadth-first search from the source that only follows
         * edges with positive remaining capacity.
         *
         * @return {@code true} if the sink is still reachable
         */
        private boolean bfs() {
            level = new int[nodes.size()];
            Arrays.fill(level, -1);
            level[nodes.get(SOURCE)] = 0;

            Queue<String> queue = new LinkedList<>();
            queue.offer(SOURCE);

            while (!queue.isEmpty()) {
                String u = queue.poll();
                int nextLevel = level[nodes.get(u)] + 1;
                for (Map.Entry<String, Integer> edge : network.get(u).entrySet()) {
                    if (edge.getValue() <= 0) {
                        continue;
                    }
                    int v = nodes.get(edge.getKey());
                    if (level[v] == -1) {
                        level[v] = nextLevel;
                        queue.offer(edge.getKey());
                    }
                }
            }
            return level[nodes.get(SINK)] != -1;
        }

        /**
         * Searches for one augmenting path from {@code u} to the sink along edges with positive
         * capacity and increasing level, and pushes flow along it.
         *
         * @param u        vertex to search from
         * @param flow     upper bound on what can still be pushed along the current path
         * @param deadEnds vertices already known to have no path to the sink in the current phase;
         *                 extended by this call
         * @return the amount pushed, or 0 when {@code u} cannot reach the sink
         */
        private int dfs(String u, int flow, Set<String> deadEnds) {
            if (u.equals(SINK)) {
                return flow;
            }

            int uLevel = level[nodes.get(u)];
            for (Map.Entry<String, Integer> edge : network.get(u).entrySet()) {
                String v = edge.getKey();
                int capacity = edge.getValue();
                if (capacity <= 0 || deadEnds.contains(v) || level[nodes.get(v)] <= uLevel) {
                    continue;
                }
                int sent = dfs(v, Math.min(flow, capacity), deadEnds);
                if (sent > 0) {
                    // setValue is not a structural change, so it is safe while iterating.
                    edge.setValue(capacity - sent);
                    network.get(v).merge(u, sent, Integer::sum);
                    return sent;
                }
            }

            // Within one phase only level-decreasing edges gain capacity, so a vertex that cannot
            // reach the sink now cannot reach it later in the same phase either.
            deadEnds.add(u);
            return 0;
        }

        /**
         * Sets the capacity of {@code from -> to} and makes sure the reverse residual entry exists
         * without clobbering a capacity that was already stored for {@code to -> from}.
         */
        private void addEdge(String from, String to, int capacity) {
            network.computeIfAbsent(from, k -> new HashMap<>()).put(to, capacity);
            network.computeIfAbsent(to, k -> new HashMap<>()).putIfAbsent(from, 0);
            nodes.putIfAbsent(from, nodes.size());
            nodes.putIfAbsent(to, nodes.size());
        }

        /**
         * Collects every vertex reachable from {@code start} over at least one edge.
         *
         * @return the reachable vertices; contains {@code start} only if it lies on a cycle
         * @throws NullPointerException if a visited vertex has no adjacency set in {@code graph}
         */
        private Set<T> reach(Map<T, Set<T>> graph, T start) {
            Set<T> reached = new HashSet<>();
            Queue<T> queue = new LinkedList<>();
            queue.add(start);

            while (!queue.isEmpty()) {
                T current = queue.poll();
                Set<T> successors = graph.get(current);
                if (successors == null) {
                    throw new NullPointerException("graph has no adjacency set for vertex " + current);
                }
                for (T next : successors) {
                    if (reached.add(next)) {
                        queue.add(next);
                    }
                }
            }

            return reached;
        }

        /**
         * Entry point: builds the flow network for {@code graph}, computes the maximum flow and
         * returns {@code graph.size() - maxFlow}.
         *
         * @param graph adjacency sets keyed by vertex; every vertex, including those without
         *              successors, must be present as a key
         * @return {@code graph.size()} minus the maximum flow; 0 for an empty graph
         * @throws NullPointerException if a successor is not itself a key of {@code graph}
         */
        public int calculator(Map<T, Set<T>> graph) {
            // Start clean so that residual edges of an earlier call cannot leak into this one.
            reset();

            for (T t : graph.keySet()) {
                String left = "A" + t;
                addEdge(SOURCE, left, 1);
                addEdge("B" + t, SINK, 1);
                for (T r : reach(graph, t)) {
                    addEdge(left, "B" + r, 1);
                }
            }

            int maxFlow = 0;
            while (bfs()) {
                // Dead ends stay valid for the whole phase, so they are shared across all paths.
                Set<String> deadEnds = new HashSet<>();
                int flow;
                while ((flow = dfs(SOURCE, Integer.MAX_VALUE, deadEnds)) > 0) {
                    maxFlow += flow;
                }
            }

            return graph.size() - maxFlow;
        }
    }