Tags: python, recursion, trie, prefix-tree

Counting keystrokes per word while traversing a trie


I'm trying to solve the keyboard autocompletion problem described here. The task is to calculate how many keystrokes each word requires, given a dictionary and a set of autocompletion rules. For example, for the dictionary:

data = ['hello', 'hell', 'heaven', 'goodbye']

We get the following results (see the link above for further explanation):

{'hell': 2, 'heaven': 2, 'hello': 3, 'goodbye': 1}

Quick explanation: if the user types h, then e is autocompleted, because all words starting with h also have e as their second letter. If the user then types l, the second l is filled in automatically, giving 2 strokes for the word hell. Of course, hello requires one more stroke. See the link above for more examples.
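To make the rules concrete, here is a minimal brute-force sketch that simulates typing each word. It assumes the rules as described above: the first letter is always typed, and after each typed letter the system keeps filling in letters as long as every dictionary word with the current prefix continues with the same next letter (a word that ends exactly at the current prefix stops autocompletion). The name brute_force_strokes is hypothetical; it is slow and only meant as a reference to check a trie-based solution on small inputs.

```python
def brute_force_strokes(words):
    """Count keystrokes per word by simulating typing (hypothetical helper)."""
    result = {}
    for word in words:
        strokes = 0
        i = 0  # number of letters of `word` currently on screen
        while i < len(word):
            # The user types the next letter.
            strokes += 1
            i += 1
            # Autocomplete while every word sharing the prefix has the same next letter.
            while i < len(word):
                prefix = word[:i]
                matches = [w for w in words if w.startswith(prefix)]
                # None stands for "a word ends exactly here", which stops autocompletion.
                next_letters = {w[i] if len(w) > i else None for w in matches}
                if len(next_letters) == 1 and None not in next_letters:
                    i += 1
                else:
                    break
        result[word] = strokes
    return result


data = ['hello', 'hell', 'heaven', 'goodbye']
print(brute_force_strokes(data))  # {'hello': 3, 'hell': 2, 'heaven': 2, 'goodbye': 1}
```

Because it rescans the whole dictionary at every prefix, this is only practical for small inputs, but it is handy for cross-checking edge cases such as a word that is also a prefix of longer words.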

My trie code is the following, and it works fine (taken from https://en.wikipedia.org/wiki/Trie). The Stack class is used to traverse the tree from the root (see the edit below):

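class Stack(object):
    """Fixed-capacity LIFO stack backed by a preallocated list."""
    def __init__(self, size):
        self.data = [None] * size
        self.i = 0          # number of items currently on the stack
        self.size = size

    def pop(self):
        # Returns None (instead of raising) when the stack is empty
        if self.i == 0:
            return None
        self.i -= 1
        return self.data[self.i]

    def push(self, item):
        # Returns None and drops the item when the stack is full
        if self.i >= self.size:
            return None
        self.data[self.i] = item
        self.i += 1
        return item

    def __str__(self):
        # __str__ must always return a string, even for an empty stack
        s = '# Stack contents #\n'
        for idx in range(self.i - 1, -1, -1):
            s += str(self.data[idx]) + '\n'
        return s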

class Trie(object):
    def __init__(self, value, children):
        self.value = value        # string: the preceding prefix, or the full word at word-end nodes
        self.children = children  # {char: Trie}

class PrefixTree(object):
    def __init__(self, data):
        self.root = Trie(None, {})
        self.data = data
        
        for w in data:
            self.insert(w, w)
    
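    def insert(self, string, value):
        node = self.root
        i = 0
        n = len(string)

        # Follow the path that already exists for a prefix of string
        while i < n:
            if string[i] in node.children:
                node = node.children[string[i]]
                i = i + 1
            else:
                break

        # Create nodes for the remaining characters; each new node
        # stores the prefix that precedes its own character
        while i < n:
            node.children[string[i]] = Trie(string[:i], {})
            node = node.children[string[i]]
            i = i + 1

        # The node of the last character stores the whole word
        node.value = value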
        
        
    def find(self, key):
        node = self.root
        for char in key:
            if char in node.children:
                node = node.children[char]
            else:
                return None
        return node

I couldn't figure out how to count the number of strokes:

data = ['hello', 'hell', 'heaven', 'goodbye']
tree = PrefixTree(data)
strokes = {w:1 for w in tree.data} #at least 1 stroke is necessary

And here's the code to traverse the tree from the root:

stack = Stack(100)
stack.push((None, tree.root))
print('Key\tChilds\tValue')
print('--' * 25)

strokes = {}

while stack.i > 0:
    key, curr = stack.pop()

    # if something:
    #     update strokes

    print('%s\t%s\t%s' % (key, len(curr.children), curr.value))
    for key, node in curr.children.items():
        stack.push((key, node))

print(strokes)

Any idea or constructive comment would help, thanks!

Edit

Great answer by @SergiyKolesnikov. One small change avoids the call to endswith(): I just added a boolean field to the Trie class:

class Trie(object):
    def __init__(self, value, children, eow=False):
        self.value = value        # prefix string, or the full word at word-end nodes
        self.children = children  # {char: Trie}
        self.eow = eow            # end of word: True if a dictionary word ends here
    

And at the end of insert():

def insert(self, string, value):
    # ...
    node.value = value
    node.eow = True

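Then just replace curr.value.endswith('$') with curr.eow. Since the value no longer ends with '$', the word is stored with strokes[curr.value] instead of strokes[curr.value[:-1]]. Thank you all!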
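Putting the edit together, a self-contained sketch might look like the following. It makes a few assumptions beyond the question's code: a simplified insert() that stores None on inner nodes (the eow flag makes the stored prefixes unnecessary), a plain Python list instead of the Stack class, and a hypothetical count_strokes() function. The counting rule follows the two marker conditions from the answer below, counting a node once even if both hold and always counting the root.

```python
class Trie(object):
    def __init__(self, value, children, eow=False):
        self.value = value        # the full word at word-end nodes, else None
        self.children = children  # {char: Trie}
        self.eow = eow            # True if a dictionary word ends at this node


class PrefixTree(object):
    def __init__(self, data):
        self.root = Trie(None, {})
        self.data = data
        for w in data:
            self.insert(w, w)

    def insert(self, string, value):
        # Simplified insert (assumption): create missing nodes along the path.
        node = self.root
        for ch in string:
            if ch not in node.children:
                node.children[ch] = Trie(None, {})
            node = node.children[ch]
        node.value = value
        node.eow = True


def count_strokes(tree):
    """Iterative DFS; each stack item carries the strokes typed so far."""
    strokes = {}
    stack = [(tree.root, 0)]
    while stack:
        curr, counter = stack.pop()
        if curr.eow:
            # No '$' suffix to strip: the value is the word itself.
            strokes[curr.value] = counter
        # One keystroke is needed to leave this node if the user must choose
        # (several children), if a word ends here but longer words continue,
        # or at the root (the first letter is always typed).
        if curr is tree.root or len(curr.children) > 1 or (curr.eow and curr.children):
            counter += 1
        for node in curr.children.values():
            stack.append((node, counter))
    return strokes


tree = PrefixTree(['hello', 'hell', 'heaven', 'goodbye'])
print(count_strokes(tree))
```

The printed dict order depends on the traversal order, but the counts should match the ones listed in the question.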


Solution

  • The trie for your example could look like this:

     ' '
    |    \
    H     G
    |     |
    E     O
    | \   |
    L  A  O
    |  |  |
    L$ V  D
    |  |  |
    O  E  B
       |  |
       N  Y
          |
          E
    

    Which nodes in the trie can be seen as markers for user keystrokes? There are two types of such nodes:

    1. Inner nodes with more than one child, because the user has to choose among multiple alternatives.
    2. Nodes that represent the last letter of a word, but are not leaves (marked with $), because the user has to type the next letter if the current word is not what is needed.

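    While traversing the trie (recursively, or iteratively with an explicit stack), one counts how many of these marker nodes were encountered before the last letter of a word was reached. A node that satisfies both conditions still counts only once, because the user types a single letter there, and the root always counts, because the first letter must always be typed. This count is the number of strokes needed for the word.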

    For the word "hell" it is two marker nodes: ' ' and E (2 strokes).
    For the word "hello" it is three marker nodes: ' ', E, L$ (3 strokes).
    And so on...
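The same counting rule can also be written as a recursive traversal. This is a minimal sketch under some assumptions: the trie is a plain nested dict rather than the question's Trie class, a hypothetical sentinel key END marks word ends (playing the role of '$'), and trie_strokes is a hypothetical name. The stroke count is passed down as an argument, so each branch keeps its own counter and no explicit stack is needed.

```python
END = object()  # hypothetical sentinel key meaning "a word ends at this node"


def build_trie(words):
    """Nested-dict trie: each key is a character, END maps to the finished word."""
    root = {}
    for word in words:
        node = root
        for ch in word:
            node = node.setdefault(ch, {})
        node[END] = word
    return root


def trie_strokes(words):
    result = {}

    def visit(node, strokes, is_root):
        children = [child for key, child in node.items() if key is not END]
        is_word_end = END in node
        if is_word_end:
            # Strokes counted on the path so far are the cost of this word.
            result[node[END]] = strokes
        # Marker node: the root, a branching node, or a word end with longer
        # words below it. Each marker costs exactly one keystroke.
        if is_root or len(children) > 1 or (is_word_end and children):
            strokes += 1
        for child in children:
            visit(child, strokes, False)

    visit(build_trie(words), 0, True)
    return result


print(trie_strokes(['hello', 'hell', 'heaven', 'goodbye']))
```

Passing strokes by value into each recursive call does the job of the third element of the triple pushed onto the stack in the code below.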

    What needs to be changed in your implementation:

    The end of a valid word needs to be marked in the tree, so that the second condition can be checked. Therefore, we change the last line of the PrefixTree.insert() method from

    node.value = value
    

    to

    node.value = value + '$'
    

    Now we add a stroke counter for each stack item (the last value in the triple pushed to the stack) and the checks that increase the counter:

    stack = Stack(100)
    stack.push((None, tree.root, 0))  # We start with stroke counter = 0
    print('Key\tChilds\tValue')
    print('--' * 25)

    strokes = {}

    while stack.i > 0:
        key, curr, stroke_counter = stack.pop()
        is_word_end = curr.value is not None and curr.value.endswith('$')

        if is_word_end:
            # The end of a valid word is reached. Save the word (without '$') and its stroke counter.
            strokes[curr.value[:-1]] = stroke_counter

        # Condition 1: more than one child, so the user has to choose the next letter.
        # Condition 2: a word ends here but longer words continue, so autocompletion stops.
        # Either way the next letter costs one keystroke, so count it once even if both hold.
        # The root always counts, because the first letter must always be typed.
        if curr is tree.root or len(curr.children) > 1 or (is_word_end and len(curr.children) > 0):
            stroke_counter += 1

        print('%s\t%s\t%s' % (key, len(curr.children), curr.value))
        for key, node in curr.children.items():
            stack.push((key, node, stroke_counter))  # Save the stroke counter

    print(strokes)
    

    Output:

    Key Childs  Value
    --------------------------------------------------
    None    2   None
    h   1   
    e   2   h
    a   1   he
    v   1   hea
    e   1   heav
    n   0   heaven$
    l   1   he
    l   1   hell$
    o   0   hello$
    g   1   
    o   1   g
    o   1   go
    d   1   goo
    b   1   good
    y   1   goodb
    e   0   goodbye$
    {'heaven': 2, 'goodbye': 1, 'hell': 2, 'hello': 3}