
Word Prediction algorithm

I'm sure there is a post on this, but I couldn't find one asking this exact question. Consider the following:

  1. We have a word dictionary available
  2. We are fed many paragraphs of words, and we wish to be able to predict the next word in a sentence given this input.

Say we have a few sentences such as "Hello my name is Tom", "His name is jerry", "He goes where there is no water". For each word, we check a hash table to see whether it already exists. If it does not, we assign it a unique ID and put it in the hash table. This way, instead of storing a "chain" of words as a bunch of strings, we can just have a list of unique IDs.
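
For concreteness, here is a minimal Python sketch of that ID-assignment step (the function and variable names are just my own illustration of what I described above):

```python
word_ids = {}  # word -> unique integer ID

def to_chain(sentence):
    """Turn a sentence into a chain of word IDs, assigning new IDs on first sight."""
    chain = []
    for word in sentence.lower().split():
        if word not in word_ids:
            word_ids[word] = len(word_ids)   # next unused ID
        chain.append(word_ids[word])
    return chain

print(to_chain("Hello my name is Tom"))             # [0, 1, 2, 3, 4]
print(to_chain("His name is jerry"))                # [5, 2, 3, 6]
print(to_chain("He goes where there is no water"))  # [7, 8, 9, 10, 3, 11, 12]
```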

Above, we would have for instance (0, 1, 2, 3, 4), (5, 2, 3, 6), and (7, 8, 9, 10, 3, 11, 12). Note that 3 is "is", and we added new unique IDs as we discovered new words. So say we are given a sentence "her name is"; this would be (13, 2, 3). We want to know, given this context, what the next word should be. This is the algorithm I thought of, but I don't think it's efficient:

  1. We have a list of N chains (observed sentences) where a chain may be ex. 3,6,2,7,8.
  2. Each chain is on average size M, where M is the average sentence length
  3. We are given a new chain of size S, e.g. (13, 2, 3), and we wish to know the most probable next word.

Algorithm:

  1. First scan the entire list of chains for those which contain the full S-word input ((13, 2, 3) in this example). Since we have to scan N chains, each of length M, and compare S words at a time, it's O(N*M*S).

  2. If there are no chains in our scan which contain the full S, scan again after removing the least significant word (i.e. the first one, so remove 13). Now scan for (2, 3) as in step 1, in worst case O(N*M*S), where the match length is really S-1.

  3. Continue scanning this way until we get results > 0 (if ever).

  4. Tally the next words in all of the remaining chains we have gathered. We can use a hash table which increments a count each time we add a word and keeps track of the most-added word so far. O(N) worst-case build, O(1) to find the max word.

  5. The max word found is the most likely next word, so return it.

Each scan takes O(M*N*S) worst case: there are N chains, each chain has M numbers, and we must compare S numbers at each position to check for a match. We scan at most S times ((13, 2, 3), then (2, 3), then (3): 3 scans = S). Thus, the total complexity is O(S^2 * M * N).
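
Here is a rough Python sketch of this scan-and-backoff idea, assuming the chains and the query are plain lists of word IDs (the names are just illustrative):

```python
from collections import Counter

def predict_next(chains, query):
    """Back off from the full query until some chain contains it, then tally
    the word that follows each match and return the most common one."""
    while query:
        tally = Counter()
        for chain in chains:                          # N chains
            for i in range(len(chain) - len(query)):  # ~M positions per chain
                if chain[i:i + len(query)] == query:  # S comparisons per position
                    tally[chain[i + len(query)]] += 1 # count the following word
        if tally:
            return tally.most_common(1)[0][0]
        query = query[1:]                             # drop the least significant (first) word
    return None                                       # nothing ever matched; no prediction

chains = [[0, 1, 2, 3, 4], [5, 2, 3, 6], [7, 8, 9, 10, 3, 11, 12]]
print(predict_next(chains, [13, 2, 3]))  # backs off to [2, 3]; 4 and 6 tie, first seen wins
```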

So if we have 100,000 chains and an average sentence length of 10 words, we're looking at 1,000,000*S^2 operations to get the optimal word. Clearly N >> M, since sentence length does not in general scale with the number of observed sentences, so M can be treated as a constant. We can then reduce the complexity to O(S^2 * N), though O(S^2 * M * N) may be more helpful for analysis, since M can be a sizeable "constant".

This could be the completely wrong approach to take for this type of problem, but I wanted to share my thoughts instead of just blatantly asking for assistance. The reason I'm scanning the way I do is that I only want to scan as much as I have to: if nothing contains the full S, just keep pruning S until some chains match. If they never match, we have no idea what to predict as the next word! Any suggestions on a less time/space-complex solution? Thanks!

asked Sep 10 '13 by user2045279


2 Answers

This is the problem of language modeling. For a baseline approach, the only thing you need is a hash table mapping fixed-length chains of words, say of length k, to the most probable following word. (*)

At training time, you break the input into (k+1)-grams using a sliding window. So if you encounter

The wrath sing, goddess, of Peleus' son, Achilles

you generate, for k=2,

START START the
START the wrath
the wrath sing
wrath sing goddess
sing goddess of
goddess of peleus
of peleus son
peleus son achilles

This can be done in linear time. For each 3-gram, tally (in a hash table) how often the third word follows the first two.
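
As a rough Python sketch of this training pass (the `<s>` start sentinel and the whitespace tokenization are simplifying assumptions of mine):

```python
from collections import defaultdict

K = 2           # context length, the k above
START = "<s>"   # sentence-start sentinel

# counts[context][next_word] = how often next_word followed that k-word context
counts = defaultdict(lambda: defaultdict(int))

def train(sentence, k=K):
    """Slide a (k+1)-word window over one sentence and tally each (k+1)-gram."""
    words = [START] * k + sentence.lower().split()
    for i in range(len(words) - k):
        context = tuple(words[i:i + k])
        counts[context][words[i + k]] += 1

# Punctuation already stripped here for simplicity.
train("The wrath sing goddess of Peleus son Achilles")
```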

Finally, loop through the hash table and for each key (2-gram) keep only the most commonly occurring third word. Linear time.

At prediction time, look only at the k (2) last words and predict the next word. This takes only constant time since it's just a hash table lookup.
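
Continuing that sketch, the pruning pass and the constant-time lookup might look like this (still using the hypothetical `counts` table built above):

```python
# Keep only the most frequent continuation for each context: one linear pass.
best_next = {context: max(followers, key=followers.get)
             for context, followers in counts.items()}

def predict(context_words, k=K):
    """Predict the next word from the last k words; None if the context is unseen."""
    key = tuple(w.lower() for w in context_words[-k:])
    return best_next.get(key)

print(predict(["the", "wrath"]))   # -> 'sing', given the training sentence above
```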

If you're wondering why you should keep only short subchains instead of full chains, then look into the theory of Markov windows. If your model were to remember all the chains of words that it has seen in its input, then it would badly overfit its training data and only reproduce its input at prediction time. How badly depends on the training set (more data is better), but for k>4 you'd really need smoothing in your model.

(*) Or to a probability distribution, but this is not needed for your simple example use case.

answered by Fred Foo


Yee Whye Teh also has some recent interesting work that addresses this problem. The "Sequence Memoizer" extends the traditional prediction-by-partial-matching scheme to take into account arbitrarily long histories.

Here is a link to the original paper: http://www.stats.ox.ac.uk/~teh/research/compling/WooGasArc2011a.pdf

It is also worth reading some of the background work, which can be found in the paper "A Bayesian Interpretation of Interpolated Kneser-Ney".

answered by user1149913