 

Injecting pre-trained word2vec vectors into TensorFlow seq2seq

I am trying to inject pre-trained word2vec vectors into an existing TensorFlow seq2seq model.

Following this answer, I produced the code below. It doesn't seem to improve performance as it should, even though the values in the variable are updated.

My understanding is that the problem might be that EmbeddingWrapper or embedding_attention_decoder creates its embeddings independently of the vocabulary order. Could that be the cause?

What would be the best way to load pre-trained vectors into a TensorFlow model?

import sys

import tensorflow as tf
from tensorflow.python.platform import gfile
import word2vec

SOURCE_EMBEDDING_KEY = "embedding_attention_seq2seq/RNN/EmbeddingWrapper/embedding"
TARGET_EMBEDDING_KEY = "embedding_attention_seq2seq/embedding_attention_decoder/embedding"


def inject_pretrained_word2vec(session, word2vec_path, input_size,
                               source_vocab_path, source_vocab_size,
                               target_vocab_path, target_vocab_size):
  word2vec_model = word2vec.load(word2vec_path, encoding="latin-1")
  print("w2v model created!")
  session.run(tf.initialize_all_variables())

  assign_w2v_pretrained_vectors(session, word2vec_model, SOURCE_EMBEDDING_KEY, source_vocab_path, source_vocab_size)
  assign_w2v_pretrained_vectors(session, word2vec_model, TARGET_EMBEDDING_KEY, target_vocab_path, target_vocab_size)


def assign_w2v_pretrained_vectors(session, word2vec_model, embedding_key, vocab_path, vocab_size):
  # Find the single trainable embedding variable matching the key.
  vectors_variable = [v for v in tf.trainable_variables() if embedding_key in v.name]
  if len(vectors_variable) != 1:
      print("Word vector variable not found or too many. key: " + embedding_key)
      print("Existing embedding trainable variables:")
      print([v.name for v in tf.trainable_variables() if "embedding" in v.name])
      sys.exit(1)

  vectors_variable = vectors_variable[0]
  vectors = vectors_variable.eval()

  with gfile.GFile(vocab_path, mode="r") as vocab_file:
      counter = 0
      while counter < vocab_size:
          vocab_w = vocab_file.readline().replace("\n", "")
          # For each word in the vocabulary, check whether a w2v vector
          # exists and inject it; otherwise leave the value unchanged.
          if vocab_w in word2vec_model:
              vectors[counter] = word2vec_model.get_vector(vocab_w)
          counter += 1

  # Re-run the variable's initializer, feeding the patched matrix in
  # place of its original initial value.
  session.run([vectors_variable.initializer],
              {vectors_variable.initializer.inputs[1]: vectors})
asked Apr 02 '16 by Vlad Kolesnyk

1 Answer

I am not familiar with the seq2seq example, but in general you can use the following code snippet to inject your embeddings:

Where you build your graph:

with tf.device("/cpu:0"):
  embedding = tf.get_variable("embedding", [vocabulary_size, embedding_size])      
  inputs = tf.nn.embedding_lookup(embedding, input_data)

When you execute (after building your graph and before starting the training), just assign your saved embeddings to the embedding variable:

session.run(tf.assign(embedding, embeddings_that_you_want_to_use))

The idea is that embedding_lookup will replace the token ids in input_data with the corresponding rows of the embedding variable, so whatever you assign to that variable is what the model sees.
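
Putting the two snippets together, here is a minimal end-to-end sketch of that approach. The file names (vectors.bin, vocab.txt) and the sizes are hypothetical placeholders; it assumes the same word2vec package the question uses, and that the word2vec vectors have dimension embedding_size:

import numpy as np
import tensorflow as tf
import word2vec  # assumption: the same `word2vec` package used in the question

# Hypothetical sizes and file names, for illustration only.
vocabulary_size = 40000
embedding_size = 300

# Build the graph exactly as in the snippets above.
with tf.device("/cpu:0"):
  embedding = tf.get_variable("embedding", [vocabulary_size, embedding_size])
  input_data = tf.placeholder(tf.int32, [None, None])
  inputs = tf.nn.embedding_lookup(embedding, input_data)

with tf.Session() as session:
  session.run(tf.initialize_all_variables())

  # Row i of the matrix must hold the vector for vocabulary word i, so
  # the matrix row order matches the model's own vocabulary order.
  model = word2vec.load("vectors.bin", encoding="latin-1")
  matrix = np.random.uniform(-0.1, 0.1,
                             (vocabulary_size, embedding_size)).astype(np.float32)
  with open("vocab.txt") as vocab_file:
    for i, line in enumerate(vocab_file):
      if i >= vocabulary_size:
        break
      w = line.strip()
      if w in model:
        matrix[i] = model.get_vector(w)

  # Overwrite the (randomly initialized) embedding variable in one shot.
  session.run(tf.assign(embedding, matrix))

Because the assign runs after the initializer, the pre-trained rows survive, and words missing from the word2vec model simply keep their random initialization.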

answered Sep 21 '22 by RaduK