Tags: algorithm, machine-learning

Word Prediction algorithm


Consider the following:

  1. We have a word dictionary available
  2. We are fed many paragraphs of words, and we wish to be able to predict the next word in a sentence given this input.

Say we have a few sentences such as "Hello my name is Tom", "His name is jerry", "He goes where there is no water". For each word, we check a hash table to see if it already exists. If it does not, we assign it a unique ID and put it in the hash table. This way, instead of storing a "chain" of words as a bunch of strings, we can just store a list of unique IDs.

Above, we would have for instance (0, 1, 2, 3, 4), (5, 2, 3, 6), and (7, 8, 9, 10, 3, 11, 12). Note that 3 is "is", and we added new unique IDs as we discovered new words. So if we are given the sentence "her name is", this would be (13, 2, 3). We want to know, given this context, what the next word should be. This is the algorithm I thought of, but I don't think it's efficient:
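Concretely, the ID assignment might look something like this (Python is just my choice for illustration, with simple whitespace tokenization assumed):

    # Assign each distinct word a unique integer ID the first time we see it.
    word_to_id = {}

    def to_chain(sentence):
        """Turn a sentence into a chain of word IDs."""
        chain = []
        for word in sentence.lower().split():
            if word not in word_to_id:
                word_to_id[word] = len(word_to_id)   # new word -> new ID
            chain.append(word_to_id[word])
        return chain

    chains = [to_chain(s) for s in ["Hello my name is Tom",
                                    "His name is jerry",
                                    "He goes where there is no water"]]
    # chains == [[0, 1, 2, 3, 4], [5, 2, 3, 6], [7, 8, 9, 10, 3, 11, 12]]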

  1. We have a list of N chains (observed sentences), where a chain may be e.g. 3, 6, 2, 7, 8.
  2. Each chain is on average of size M, where M is the average sentence length.
  3. We are given a new chain of size S, e.g. (13, 2, 3), and we wish to know the most probable next word.

Algorithm:

  1. First, scan the entire list of chains for those that contain the full S-word input ((13, 2, 3) in this example). Since we have to scan N chains, each of length M, and compare S words at a time, this is O(NMS).

  2. If no chains in our scan contain the full S, scan again after removing the least significant word (i.e. the first one, so remove 13). Now scan for (2, 3) as in step 1, again O(NMS) in the worst case (though the effective length is really S-1).

  3. Continue scanning this way until we get at least one matching chain (if ever).

  4. Tally the next words in all of the matching chains we have gathered. We can use a hash table that counts each word as we add it and keeps track of the most frequently added word. O(N) worst-case build, O(1) to find the max word.

  5. The max word found is the most likely, so return it.

Each scan takes O(MNS) worst case: there are N chains, each chain has M numbers, and we must compare S numbers at each position to check for a match. We scan S times in the worst case (13, 2, 3, then 2, 3, then 3, for 3 scans = S). Thus, the total complexity is O(S^2 * M * N).

So if we have 100,000 chains and an average sentence length of 10 words, we're looking at 1,000,000 * S^2 operations to get the optimal word. Clearly N >> M, since sentence length does not in general scale with the number of observed sentences, so M can be treated as a constant. We can then reduce the complexity to O(S^2 * N). O(S^2 * M * N) may be more helpful for analysis though, since M can be a sizeable "constant".
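As a rough sketch of what I mean (Python, purely illustrative; the function and variable names are my own):

    from collections import Counter

    def predict_next(chains, query):
        """Scan for the longest suffix of `query` that appears in some chain,
        then return the most common word that follows it."""
        # Prune from the front: drop the least significant word each pass.
        for start in range(len(query)):
            pattern = query[start:]
            s = len(pattern)
            tally = Counter()
            for chain in chains:                    # N chains
                for i in range(len(chain) - s):     # ~M positions per chain
                    if chain[i:i + s] == pattern:   # S comparisons per position
                        tally[chain[i + s]] += 1    # count the following word
            if tally:
                return tally.most_common(1)[0][0]
        return None  # nothing ever matched; no prediction possible

    # predict_next([[0,1,2,3,4], [5,2,3,6], [7,8,9,10,3,11,12]], [13, 2, 3])
    # falls back to (2, 3) and returns 4 ("tom"; it ties with 6, "jerry").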

This could be the completely wrong approach to take for this type of problem, but I wanted to share my thoughts instead of just blatantly asking for assistance. The reason I'm scanning the way I do is that I only want to scan as much as I have to. If nothing has the full S, just keep pruning S until some chains match. If they never match, we have no idea what to predict as the next word! Any suggestions for a less time/space-complex solution?


Solution

  • This is the problem of language modeling. For a baseline approach, the only thing you need is a hash table mapping fixed-length chains of words, say of length k, to the most probable following word. (*)

    At training time, you break the input into (k+1)-grams using a sliding window. So if you encounter

    The wrath sing, goddess, of Peleus' son, Achilles
    

    you generate, for k=2,

    START START the
    START the wrath
    the wrath sing
    wrath sing goddess
    sing goddess of
    goddess of peleus
    of peleus son
    peleus son achilles
    

    This can be done in linear time. For each 3-gram, tally (in a hash table) how often the third word follows the first two.

    Finally, loop through the hash table and for each key (2-gram) keep only the most commonly occurring third word. Linear time.
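    A minimal sketch of this training step (Python; the START token and naive whitespace tokenization are simplifications of my own):

    from collections import defaultdict, Counter

    START = "<s>"

    def train(sentences, k=2):
        """Map each k-gram (tuple of k words) to its most frequent next word."""
        counts = defaultdict(Counter)
        for sentence in sentences:
            words = [START] * k + sentence.lower().split()
            # Slide a (k+1)-word window and tally the follower of each k-gram.
            for i in range(len(words) - k):
                counts[tuple(words[i:i + k])][words[i + k]] += 1
        # Keep only the most common follower per k-gram -- linear in table size.
        return {gram: tally.most_common(1)[0][0]
                for gram, tally in counts.items()}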

    At prediction time, look only at the last k (here 2) words and predict the next word. This takes only constant time, since it's just a hash table lookup.
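    Continuing the sketch above, prediction is then a single dictionary lookup on the last k words:

    def predict(model, context_words, k=2):
        """Return the most likely next word for the last k words seen, or None."""
        return model.get(tuple(w.lower() for w in context_words[-k:]))

    model = train(["Hello my name is Tom", "His name is jerry"])
    predict(model, ["her", "name", "is"])   # -> "tom" (ties with "jerry"; first occurrence wins)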

    If you're wondering why you should keep only short subchains instead of full chains, then look into the theory of Markov windows. If your model were to remember all the chains of words that it has seen in its input, then it would badly overfit its training data and only reproduce its input at prediction time. How badly depends on the training set (more data is better), but for k>4 you'd really need smoothing in your model.

    (*) Or to a probability distribution, but this is not needed for your simple example use case.