javaperformanceoptimizationmemory-managementunion-find

How to optimize a Java Union-Find program to avoid OutOfMemoryError when processing large datasets


This is a follow-up to my earlier question

I've managed to implement a working solution:

package com.test;

import java.io.*;
import java.util.*;

public class LineGroupProcessor {

    private LineGroupProcessor() {
    }

    public static void main(String[] args) {
        validateArgs(args);
        List<String[]> validRows = readValidRows(args[0]);
        UnionFind unionFind = new UnionFind(validRows.size());
        Map<String, Integer> columnValueMap = new HashMap<>();
        for (int i = 0; i < validRows.size(); i++) {
            processRow(validRows, columnValueMap, unionFind, i);
        }
        writeOutput(groupAndSortRows(validRows, unionFind));
    }

    private static void validateArgs(String[] args) {
        if (args.length == 0) {
            throw new IllegalArgumentException("No input file provided. Please specify a text or CSV file.");
        }

        String filePath = args[0];
        if (!filePath.endsWith(".txt") && !filePath.endsWith(".csv")) {
            throw new IllegalArgumentException("Invalid file type. Please provide a text or CSV file.");
        }

        File file = new File(filePath);
        if (!file.exists() || !file.isFile()) {
            throw new IllegalArgumentException("File does not exist or is not a valid file: " + filePath);
        }
    }

    private static List<String[]> readValidRows(String filePath) {
        List<String[]> rows = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(filePath))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] columns = line.split(";");
                if (isValidRow(columns)) {
                    rows.add(columns);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return rows;
    }

    private static boolean isValidRow(String[] columns) {
        for (String column : columns) {
            if (column.isEmpty() && !column.matches("^\"\\d{11}\"$")) {
                return false;
            }
        }
        return true;
    }

    private static void processRow(List<String[]> rows, Map<String, Integer> columnValueMap, UnionFind uf, int rowIndex) {
        String[] row = rows.get(rowIndex);
        for (int j = 0; j < row.length; j++) {
            String value = row[j].trim();
            if (!value.isEmpty() && !value.equals("\"\"")) {
                StringBuilder keyBuilder = new StringBuilder();
                keyBuilder.append(j).append(",").append(value);
                String key = keyBuilder.toString();
                if (columnValueMap.containsKey(key)) {
                    int prevRowIdx = columnValueMap.get(key);
                    uf.union(rowIndex, prevRowIdx);
                } else {
                    columnValueMap.put(key, rowIndex);
                }
            }
        }
    }

    private static List<Set<String>> groupAndSortRows(List<String[]> rows, UnionFind uf) {
        Map<Integer, Set<String>> groups = new HashMap<>();
        for (int i = 0; i < rows.size(); i++) {
            int group = uf.find(i);
            groups.computeIfAbsent(group, k -> new HashSet<>()).add(Arrays.toString(rows.get(i)));
        }

        List<Set<String>> sortedGroups = new ArrayList<>(groups.values());
        sortedGroups.sort((g1, g2) -> Integer.compare(g2.size(), g1.size()));
        return sortedGroups;
    }

    private static void writeOutput(List<Set<String>> sortedGroups) {
        long groupsWithMoreThanOneRow = sortedGroups.stream().filter(group -> group.size() > 1).count();
        try (PrintWriter writer = new PrintWriter("output.txt")) {
            writer.println("Total number of groups with more than one element: " + groupsWithMoreThanOneRow);
            writer.println();
            int groupNumber = 1;
            for (Set<String> group : sortedGroups) {
                writer.println("Group " + groupNumber);
                for (String row : group) {
                    writer.println(row);
                }
                writer.println();
                groupNumber++;
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

package com.test;

public class UnionFind {

    private final int[] parent;
    private final int[] rank;

    public UnionFind(int size) {
        parent = new int[size];
        rank = new int[size];
        for (int i = 0; i < size; i++) {
            parent[i] = i;
            rank[i] = 0;
        }
    }

    public int find(int index) {
        if (parent[index] != index) {
            parent[index] = find(parent[index]);
        }
        return parent[index];
    }

    public void union(int index1, int index2) {
        int element1 = find(index1);
        int element2 = find(index2);
        if (element1 != element2) {
            if (rank[element1] > rank[element2]) {
                parent[element2] = element1;
            } else if (rank[element1] < rank[element2]) {
                parent[element1] = element2;
            } else {
                parent[element2] = element1;
                rank[element1]++;
            }
        }
    }
}

The program has specific requirements: it should complete within 30 seconds and use a maximum of 1GB of memory (-Xmx1G).

When running test datasets of 1 million and 10 million rows, I get the following errors:

> Task :com.test.LineGroupProcessor.main()
Exception in thread "main" java.lang.OutOfMemoryError: Java heap space
    at com.test.LineGroupProcessor.lambda$groupAndSortRows$0(LineGroupProcessor.java:85)
    at com.test.LineGroupProcessor$$Lambda/0x000002779d000400.apply(Unknown Source)
    at java.base/java.util.HashMap.computeIfAbsent(HashMap.java:1228)
    at com.test.LineGroupProcessor.groupAndSortRows(LineGroupProcessor.java:85)
    at com.test.LineGroupProcessor.main(LineGroupProcessor.java:19)
    
> Task :com.test.LineGroupProcessor.main()
Exception in thread "main" java.lang.OutOfMemoryError: Java heap space: failed reallocation of scalar replaced objects
    at java.base/java.util.HashMap.computeIfAbsent(HashMap.java:1222)
    at com.test.LineGroupProcessor.groupAndSortRows(LineGroupProcessor.java:85)
    at com.test.LineGroupProcessor.main(LineGroupProcessor.java:19) 

How can I optimize the code to stay within the 1GB memory limit?


Solution

  • Your current approach (roughly)

    1. read the entire file into memory into List<String[]> – so if it's 1 million lines, you'll get a List of 1 million elements. Also, instead of storing each line of text as a String like "200";"123";"100", you're creating three (3) separate Strings ("200", "123", and "100") and an array.
    2. iterate through the List to construct an instance of UnionFind
    3. iterate once more through the List to build a HashMap named "groups" which includes a copy of the line input (the original line among the 1 million lines from the file)
    4. some sorting logic, then final pass through sorted data to print things out

    A few observations

    Suggestions