java spell-checking levenshtein-distance edit-distance damerau-levenshtein

Finding which error(s) are detected by Damerau-Levenshtein edit distance algorithm


I'm creating a spelling correction tool and want to implement a noisy channel model using Bayes' theorem. To do so, I need to calculate the probability P(X|W), where X is the given (misspelled) word and W is a possible correction. This probability is looked up in a confusion matrix, and which entry to use depends on knowing the type of error that happened. For example, if X = "egh" and W = "egg", the edit distance is 1 and the error is a substitution at character index 2 (0-based).
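For reference, here is the decision rule behind this approach. By Bayes' theorem the best correction maximizes the posterior, and because P(X) is the same for every candidate W, it can be dropped from the maximization:

```latex
\hat{W} = \arg\max_{W} P(W \mid X)
        = \arg\max_{W} \frac{P(X \mid W)\,P(W)}{P(X)}
        = \arg\max_{W} P(X \mid W)\,P(W)
```

Here P(W) is the prior (language-model) probability of the candidate, and P(X|W) is the error-model term read from the confusion matrix.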

I'm trying to find a way to get the error "type" as well as the character where it happened, but can't seem to make it work. I've tried creating a TreeMap and inserting i/j values whenever an error is detected, but it didn't work.

I can assume that there's only one error, i.e. the edit distance is exactly 1.
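Because the distance is assumed to be exactly 1, one possible shortcut is to skip the distance matrix entirely: find the first position where the two words diverge and compare the remaining suffixes. Note that this is a different technique from the dynamic-programming algorithm below. The sketch uses hypothetical names (`SingleEditClassifier`, `EditError`, `classify`). Error types are named from the channel's point of view (W → X), so "ins" means X contains an extra character, and `index` is the first diverging position (0-based):

```java
// Hypothetical helper: classifies the single error that turns the intended
// word w into the observed word x, assuming their edit distance is exactly 1.
class SingleEditClassifier {

    // Error type ("sub", "ins", "del" or "trans") and the 0-based position
    // where x and w first diverge.
    static final class EditError {
        final String type;
        final int index;

        EditError(String type, int index) {
            this.type = type;
            this.index = index;
        }

        @Override
        public String toString() {
            return type + "@" + index;
        }
    }

    static EditError classify(String x, String w) {
        // p = length of the common prefix = first position where x and w differ
        int p = 0;
        while (p < x.length() && p < w.length() && x.charAt(p) == w.charAt(p)) {
            p++;
        }

        if (x.length() == w.length()) {
            if (p == x.length()) {
                throw new IllegalArgumentException("strings are identical");
            }
            // Transposition: x = ...ab..., w = ...ba..., rest identical
            if (p + 1 < x.length()
                    && x.charAt(p) == w.charAt(p + 1)
                    && x.charAt(p + 1) == w.charAt(p)
                    && x.substring(p + 2).equals(w.substring(p + 2))) {
                return new EditError("trans", p);
            }
            // Substitution: only position p differs
            if (x.substring(p + 1).equals(w.substring(p + 1))) {
                return new EditError("sub", p);
            }
        } else if (x.length() == w.length() + 1) {
            // Insertion: x has an extra character at position p
            if (x.substring(p + 1).equals(w.substring(p))) {
                return new EditError("ins", p);
            }
        } else if (x.length() + 1 == w.length()) {
            // Deletion: x is missing the character w.charAt(p)
            if (x.substring(p).equals(w.substring(p + 1))) {
                return new EditError("del", p);
            }
        }
        throw new IllegalArgumentException("edit distance is not exactly 1");
    }
}
```

For example, `classify("egh", "egg")` gives a substitution at index 2, and `classify("teh", "the")` gives a transposition at index 1. When the extra or missing character sits in a run of identical letters (e.g. "eggg" vs "egg"), its position is ambiguous. In that case this sketch reports the position just after the common prefix.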

Here's my code:

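import java.util.HashMap;
import java.util.Map;

// Unrestricted Damerau-Levenshtein distance (Lowrance-Wagner algorithm)
public static int DLD(String s1, String s2) {
    if (s1 == null || s2 == null) {  // Invalid input
        return -1;
    }

    if (s1.equals(s2)) {  // No distance to compute
        return 0;
    }

    // The max possible distance, used as "infinity" in the sentinel row and column
    int inf = s1.length() + s2.length();

    // da maps each character to the last row (1-based) of s1 where it was seen; 0 = not yet
    Map<Character, Integer> da = new HashMap<>();
    for (int i = 0; i < s1.length(); ++i) {
        da.put(s1.charAt(i), 0);
    }
    for (int j = 0; j < s2.length(); ++j) {
        da.put(s2.charAt(j), 0);
    }

    // Create the distance matrix H[0 .. s1.length+1][0 .. s2.length+1].
    // Row and column 0 are sentinels, so distances[i + 1][j + 1] is the distance
    // between the first i characters of s1 and the first j characters of s2.
    int[][] distances = new int[s1.length() + 2][s2.length() + 2];

    // initialize the left and top edges of H (including the corner)
    distances[0][0] = inf;
    for (int i = 0; i <= s1.length(); ++i) {
        distances[i + 1][0] = inf;
        distances[i + 1][1] = i;
    }

    for (int j = 0; j <= s2.length(); ++j) {
        distances[0][j + 1] = inf;
        distances[1][j + 1] = j;
    }

    // fill in the distance matrix H
    // look at each character in s1
    for (int i = 1; i <= s1.length(); ++i) {
        int db = 0;  // last column in this row where s1[i - 1] matched s2

        // look at each character in s2
        for (int j = 1; j <= s2.length(); ++j) {
            int i1 = da.get(s2.charAt(j - 1));  // last row where s2[j - 1] occurred in s1
            int j1 = db;

            int cost = 1;
            if (s1.charAt(i - 1) == s2.charAt(j - 1)) {
                cost = 0;
                db = j;
            }

            distances[i + 1][j + 1] = min(
                    distances[i][j] + cost,   // substitution (or match)
                    distances[i + 1][j] + 1,  // insertion
                    distances[i][j + 1] + 1,  // deletion
                    distances[i1][j1] + (i - i1 - 1) + 1 + (j - j1 - 1));  // transposition
        }

        da.put(s1.charAt(i - 1), i);
    }

    return distances[s1.length() + 1][s2.length() + 1];
}

// Minimum of any number of ints (the helper DLD calls above)
private static int min(int... values) {
    int result = values[0];
    for (int v : values) {
        result = Math.min(result, v);
    }
    return result;
}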

Any hint/direction towards solving this would be much appreciated.

Thanks!

Edit 1: I figured something out and it seems to be working, although I'm not 100% sure. I replaced the code segment where I use the min() method with this:

int sub = distances[i][j] + cost;
int ins = distances[i + 1][j] + 1;
int del = distances[i][j + 1] + 1;
int trans = distances[i1][j1] + (i - i1 - 1) + 1 + (j - j1 - 1);

distances[i + 1][j + 1] = min(sub, ins, del, trans);

if ((distances[i][j] == 0 || distances[i - 1][j] == 0 ||
        distances[i][j - 1] == 0 || distances[i + 1][j + 1] == trans) &&
        distances[i + 1][j + 1] == 1) {
    TreeMap<String, Integer> error = mappingTermAndError.getOrDefault(s2, null);
    if (error != null) {
        error.clear();
    } else {
        error = new TreeMap<>();
    }

    if (distances[i + 1][j + 1] == trans) {
        error.put("trans", i - 2);

    } else if (distances[i + 1][j + 1] == del) {
        error.put("del", i - 1);

    } else if (distances[i + 1][j + 1] == ins) {
        error.put("ins", i - 1);

    } else {  // distances[i + 1][j + 1] == sub
        error.put("sub", i - 1);
    }
    mappingTermAndError.put(s2, error);
}

What it basically does is compute the value for each error type and then take the minimum. If the new minimum is 1 (so this is the first error) and either one of the previous cells in the distance matrix is 0 (meaning there's an error-free path leading to that point) or the error is a transposition (which we can only detect after an error has already been counted), then I replace the previously registered error with the new one and store the 'i' corresponding to the character where the error occurred.

I'm aware that this solution is pretty ugly and probably not very efficient, so if someone has any thoughts on how to improve that it would be great.


Solution

  • The error type and characters involved have to be stored somewhere. You can keep them in separate data structures, or you can encapsulate them in objects.

    Here's what it could look like using objects. For simplicity I'm implementing only Levenshtein distance, but I'm sure you can easily apply the technique to Damerau–Levenshtein.

    First you need to define a class that encapsulates the information about an edit: cost, parent, and any extra information such as the type (replace, insert, delete) or the characters involved. To keep things simple I'm storing this extra info in a single string called "type", but you would want separate fields for the error type, the character indices, etc. You may even want to use inheritance to create different subtypes of edits with different behavior.

    class Edit implements Comparable<Edit> {
        int cost;
        Edit parent;
        String type;
    
        public Edit() {
            // create a "start" node with no parent and zero cost
        }
    
        public Edit(String type, Edit parent, int cost) {
            this.type = type;
            this.cost = parent.cost + cost;
            this.parent = parent;
        }
    
        @Override
        public int compareTo(Edit o) {
            return Integer.compare(this.cost, o.cost);
        }
    
        @Override
        public String toString() {
            return type;
        }
    }
    

    Then you use this class instead of just int for the distance table. At 0,0 there is a special start node with no parent. At every other cell you build one candidate node per possible parent and keep the one with the lowest total cost. To be more flexible, let's move the building of the matrix out of the editDistance method:

    Edit[][] buildMatrix(String s1, String s2) {
        Edit[][] distance = new Edit[s1.length() + 1][s2.length() + 1];
    
        distance[0][0] = new Edit();
        for (int i = 1; i <= s1.length(); i++) {
            distance[i][0] = new Edit("-" + s1.charAt(i - 1), distance[i - 1][0], 1);
        }
        for (int j = 1; j <= s2.length(); j++) {
            distance[0][j] = new Edit("+" + s2.charAt(j - 1), distance[0][j - 1], 1);
        }
    
        for (int i = 1; i <= s1.length(); i++) {
            for (int j = 1; j <= s2.length(); j++) {
                int replaceCost = s1.charAt(i - 1) == s2.charAt(j - 1) ? 0 : 1;
                distance[i][j] = Collections.min(List.of(
                    // replace or same
                    new Edit(s1.charAt(i - 1) + "/" + s2.charAt(j - 1), distance[i - 1][j - 1], replaceCost),
                    // delete
                    new Edit("-" + s1.charAt(i - 1), distance[i - 1][j], 1),
                    // insert
                    new Edit("+" + s2.charAt(j - 1), distance[i][j - 1], 1)));
            }
        }
    
        return distance;
    }
    

    Then the "edit distance" function only needs to take the cost of the last node:

    int editDistance(String s1, String s2) {
        Edit[][] distance = buildMatrix(s1, s2);
        return distance[s1.length()][s2.length()].cost;
    }
    

    But thanks to the "parent" pointers, you can also easily construct the list of edits needed to change one string to the other, also known as a "diff":

    List<Edit> diff(String s1, String s2) {
        Edit[][] distance = buildMatrix(s1, s2);
        List<Edit> diff = new ArrayList<>();
        Edit edit = distance[s1.length()][s2.length()];
        while (edit != distance[0][0]) {
            diff.add(edit);
            edit = edit.parent;
        }
        Collections.reverse(diff);
        return diff;
    }
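Following the answer's suggestion, here is a minimal sketch of how the same parent-pointer technique might be extended with a transposition step. Assumptions: it implements the restricted "optimal string alignment" variant of Damerau–Levenshtein (a transposed pair cannot be edited again), not the full Lowrance–Wagner algorithm from the question. When the distance is exactly 1, as in the question, both variants give the same result. The names `DlEdit`, `DamerauDiff` and `singleError` and the type labels are hypothetical. Following the answer's advice, type and position are kept in separate fields, and `index` is a 0-based position in s1 (for "ins", the position before which the character is inserted):

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// One cell of the edit matrix: the last step taken, where it applies,
// the total cost from the start node, and a pointer to the previous cell.
class DlEdit implements Comparable<DlEdit> {
    final String type;    // "start", "same", "sub", "del", "ins" or "trans"
    final int index;      // 0-based position in s1 the step applies to
    final int cost;       // accumulated cost from the start node
    final DlEdit parent;  // null only for the start node

    DlEdit() {  // the start node at [0][0]
        this("start", -1, 0, null);
    }

    DlEdit(String type, int index, int stepCost, DlEdit parent) {
        this.type = type;
        this.index = index;
        this.parent = parent;
        this.cost = (parent == null ? 0 : parent.cost) + stepCost;
    }

    @Override
    public int compareTo(DlEdit o) {
        return Integer.compare(cost, o.cost);
    }

    @Override
    public String toString() {
        return type + "@" + index;
    }
}

class DamerauDiff {
    // Optimal-string-alignment matrix built from DlEdit nodes instead of ints.
    static DlEdit[][] buildMatrix(String s1, String s2) {
        int n = s1.length(), m = s2.length();
        DlEdit[][] d = new DlEdit[n + 1][m + 1];
        d[0][0] = new DlEdit();
        for (int i = 1; i <= n; i++) {
            d[i][0] = new DlEdit("del", i - 1, 1, d[i - 1][0]);
        }
        for (int j = 1; j <= m; j++) {
            d[0][j] = new DlEdit("ins", 0, 1, d[0][j - 1]);
        }
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                char a = s1.charAt(i - 1), b = s2.charAt(j - 1);
                List<DlEdit> options = new ArrayList<>();
                options.add(a == b
                        ? new DlEdit("same", i - 1, 0, d[i - 1][j - 1])
                        : new DlEdit("sub", i - 1, 1, d[i - 1][j - 1]));
                options.add(new DlEdit("del", i - 1, 1, d[i - 1][j]));
                options.add(new DlEdit("ins", i, 1, d[i][j - 1]));  // insert before s1[i]
                // Adjacent transposition of s1[i - 2] and s1[i - 1]
                if (i > 1 && j > 1 && a == s2.charAt(j - 2) && s1.charAt(i - 2) == b) {
                    options.add(new DlEdit("trans", i - 2, 1, d[i - 2][j - 2]));
                }
                d[i][j] = Collections.min(options);
            }
        }
        return d;
    }

    // Steps from the start node to the bottom-right cell, in order.
    static List<DlEdit> diff(String s1, String s2) {
        DlEdit[][] d = buildMatrix(s1, s2);
        List<DlEdit> path = new ArrayList<>();
        for (DlEdit e = d[s1.length()][s2.length()]; e.parent != null; e = e.parent) {
            path.add(e);
        }
        Collections.reverse(path);
        return path;
    }

    // The only step with a non-zero cost, assuming the distance is exactly 1.
    static DlEdit singleError(String s1, String s2) {
        for (DlEdit e : diff(s1, s2)) {
            if (!e.type.equals("same")) {
                return e;
            }
        }
        throw new IllegalArgumentException("strings are identical");
    }
}
```

For example, `singleError("egh", "egg")` yields a "sub" step at index 2, and `singleError("teh", "the")` yields a "trans" step at index 1. If s1 is the misspelled word X, "del" and "ins" describe how to turn X back into W. That is the reverse of the channel direction W → X, so the labels may need to be swapped before the confusion-matrix lookup.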