TensorFlow implementation of word2vec

The TensorFlow tutorial here refers to their basic implementation, which you can find on GitHub here, where the TensorFlow authors implement word2vec embedding training/evaluation with the Skip-gram model.

My question is about the actual generation of (target, context) pairs in the generate_batch() function.

On this line, the TensorFlow authors randomly sample nearby target indices around the "center" word index in the sliding window of words.

However, they also keep a data structure targets_to_avoid to which they first add the "center" context word (which of course we don't want to sample) but ALSO every other word once it has been sampled.
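For reference, the sampling being described amounts to rejection sampling over positions in the window. A minimal, self-contained sketch of just that idea (hypothetical helper name, not the verbatim tutorial code):

import random

def sample_targets(skip_window, num_skips):
    # Toy sketch of the sampling described above: pick num_skips distinct
    # positions in the window [skip_window target skip_window], never the
    # center position itself, by resampling until an unused position is hit.
    span = 2 * skip_window + 1
    target = skip_window
    targets_to_avoid = [skip_window]       # the center word's own position
    picked = []
    for _ in range(num_skips):
        while target in targets_to_avoid:  # resample until we hit a new position
            target = random.randint(0, span - 1)
        targets_to_avoid.append(target)
        picked.append(target)
    return picked

print(sample_targets(skip_window=2, num_skips=2))  # e.g. [4, 1]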

My questions are as follows:

  1. Why sample from this sliding window around the word? Why not just have a loop and use all of the window rather than sampling? It seems strange that they would worry about performance/memory in word2vec_basic.py (their "basic" implementation).
  2. Whatever the answer to 1) is, why are they both sampling and keeping track of what they've selected with targets_to_avoid? If they wanted truly random sampling, they'd sample with replacement, and if they wanted to ensure they got all the options, they should have just used a loop and taken them all in the first place!
  3. Does the built-in tf.models.embedding.gen_word2vec work this way too? If so, where can I find the source code? (I couldn't find the .py file in the GitHub repo.)

Thanks!

asked Jun 29 '16 by lollercoaster




2 Answers

I tried out your proposed way to generate batches - having a loop and using the whole skip-window. The results are:

1. Faster generation of batches

For a batch size of 128 and a skip window of 5:

  • generating batches by looping over the data one by one takes 0.73s per 10,000 batches
  • generating batches with the tutorial code and num_skips=2 takes 3.59s per 10,000 batches

2. Higher danger of overfitting

Keeping the rest of the tutorial code as it is, I trained the model both ways and logged the average loss every 2,000 steps:

[Plot: average loss, logged every 2,000 steps, for both batch-generation methods]

This pattern occurred repeatedly. It shows that using 10 samples per word instead of 2 can cause overfitting.

Here is the code that I used for generating the batches. It replaces the tutorial's generate_batch function.

import numpy as np

# `data` is the list of word ids built earlier in the tutorial.
data_index = 0

def generate_batch(batch_size, skip_window):
    global data_index
    batch = np.ndarray(shape=(batch_size), dtype=np.int32)  # Row of context words
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)  # Column of target words

    # For each word in the data, add the context to the batch and the word to the labels
    batch_index = 0
    while batch_index < batch_size:
        context = data[get_context_indices(data_index, skip_window)]

        # Add the context to the remaining batch space
        remaining_space = min(batch_size - batch_index, len(context))
        batch[batch_index:batch_index + remaining_space] = context[0:remaining_space]
        labels[batch_index:batch_index + remaining_space] = data[data_index]

        # Update the data_index and the batch_index
        batch_index += remaining_space
        data_index = (data_index + 1) % len(data)

    return batch, labels

Edit: get_context_indices is a simple function which returns the slice of indices within skip_window around data_index. See the slice() documentation for more info.
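That helper isn't shown in the answer, so here is a minimal sketch consistent with that description (an assumption, not the answer's actual code). Returning a slice object means data can stay a plain Python list, as in the tutorial:

def get_context_indices(index, skip_window):
    # Hypothetical sketch: a slice covering the skip_window words on either
    # side of `index`, clamped to the bounds of `data`. Note that a plain
    # slice also includes the center word itself; exclude it separately if
    # that is not desired.
    start = max(0, index - skip_window)
    end = min(len(data), index + skip_window + 1)
    return slice(start, end)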

answered by Kilian Batzner


There is a parameter named num_skips which denotes the number of (input, output) pairs generated from a single window: [skip_window target skip_window]. So num_skips restricts the number of context words used as output words, and that is why the generate_batch function asserts num_skips <= 2*skip_window. The code just randomly picks num_skips context words to construct training pairs with the target. But I don't know how num_skips affects performance.
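As a toy illustration of that relationship (hypothetical names, not the tutorial's code): with skip_window = 2 there are 2 * skip_window = 4 candidate context words per window, and num_skips chooses how many of them get paired with the target in one pass.

import random

def pairs_for_window(window, num_skips):
    # `window` is [w-skip_window, ..., target, ..., w+skip_window].
    # Pick num_skips distinct context words and pair each with the target.
    center = len(window) // 2
    candidates = [i for i in range(len(window)) if i != center]
    return [(window[center], window[i]) for i in random.sample(candidates, num_skips)]

window = ['the', 'quick', 'brown', 'fox', 'jumps']  # skip_window = 2, target = 'brown'
print(pairs_for_window(window, num_skips=2))
# e.g. [('brown', 'quick'), ('brown', 'jumps')]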

answered by user1903382