
Training an RNN to output word2vec embedding instead of logits

Traditionally, RNNs seem to output logits to predict the next time step in a sequence. In my case, I need the RNN to output a word2vec vector prediction (depth 50). This means the cost function has to be based on two vectors: Y, the actual vector of the next word in the series, and Y_hat, the network's prediction.

I've tried using a cosine distance cost function, but the network does not seem to learn (I've let it run for over 10 hours on an AWS P3 and the cost always stays around 0.7).

Is such a model possible at all? If so, what cost function should be used?

Cosine distance in TF:

cosine_distance = tf.losses.cosine_distance(tf.nn.l2_normalize(outputs, 2), tf.nn.l2_normalize(targets, 2), axis=2)

Update: I am trying to predict a word2vec vector so that, during sampling, I can pick the next word from the nearest neighbors of the predicted vector.
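The sampling step described in the update could look roughly like the sketch below. It is only an illustration under stated assumptions: the helper name nearest_words, the three-word vocabulary, and the 3-d vectors are all made up (real word2vec vectors would be 50-d here), and it uses plain NumPy rather than TensorFlow.

```python
import numpy as np

def nearest_words(predicted, embeddings, vocab, k=3):
    """Return the k words whose embeddings are most cosine-similar to `predicted`.

    Hypothetical helper for illustration; not part of any library.
    """
    # Scale every row to unit length so a dot product equals cosine similarity.
    emb_unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    pred_unit = predicted / np.linalg.norm(predicted)
    sims = emb_unit @ pred_unit
    # Indices of the k largest similarities, best first.
    top = np.argsort(-sims)[:k]
    return [(vocab[i], float(sims[i])) for i in top]

# Toy data (made-up values, not real word2vec vectors).
vocab = ["cat", "dog", "car"]
embeddings = np.array([
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 0.0, 1.0],
])
predicted = np.array([0.8, 0.2, 0.0])  # stand-in for the RNN output Y_hat

neighbors = nearest_words(predicted, embeddings, vocab, k=2)
print(neighbors)
```

Because the lookup only compares directions, the length of the predicted vector does not affect which words come back.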

silkAdmin asked Feb 17 '18 16:02



1 Answer

What is the reason that you want to predict a word embedding? Where are you getting the "ground truth" word embeddings from? With word2vec models, you typically re-use the trained word embeddings in future models. If you trained a word2vec model with an embedding size of 50, then you would have 50-d embeddings that you could save and use in future models. If you just want to re-create an existing ground truth word2vec model, then you could just use those values. A typical word2vec setup trains with regular softmax outputs via continuous bag-of-words or skip-gram and then saves the resulting word embeddings.

If you really do have a reason for building a model that tries to match word2vec, then here are a few suggestions about your loss function. I do not believe that you should be normalizing your outputs or your targets; you probably want those to remain unaffected (the targets are no longer the "ground truth" targets if you have normalized them). Also, it appears you are using dim=0, which has now been deprecated and replaced with axis. Did you try different values for dim? This should be the dimension along which to compute the cosine distance, and I think the 0th dimension would be the wrong one (as it is most likely the batch dimension). I would try axis=-1 (the last dimension) or axis=1 and see if you observe any difference.
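To see why the choice of axis matters, here is a small NumPy sketch rather than the TensorFlow call itself. The [batch, time, depth] shapes, the random data, and the cosine_distance helper are assumptions made for illustration. The helper divides by the vector norms itself, so it is a true cosine distance whatever the input lengths; tf.losses.cosine_distance, as far as its documentation states, expects inputs that are already unit-normalized.

```python
import numpy as np

def cosine_distance(a, b, axis=-1, eps=1e-12):
    """Mean of 1 - cos(a, b), with the cosine taken along `axis`.

    Hypothetical NumPy helper; it is not the TensorFlow implementation.
    """
    dot = np.sum(a * b, axis=axis)
    norms = np.linalg.norm(a, axis=axis) * np.linalg.norm(b, axis=axis)
    return float(np.mean(1.0 - dot / (norms + eps)))

rng = np.random.default_rng(0)

# Assumed layout: 4 sequences, 7 time steps, 50-d embeddings.
targets = rng.normal(size=(4, 7, 50))

# "Perfect" predictions: each output points in exactly the same direction
# as its target, but has a different (positive) length.
scale = rng.uniform(0.5, 2.0, size=(4, 7, 1))
outputs = scale * targets

# Along the embedding axis the directions match, so the distance is ~0.
loss_last = cosine_distance(outputs, targets, axis=-1)

# Along the batch axis the "vectors" mix values from unrelated sequences,
# so the same perfect predictions still get a non-zero loss.
loss_batch = cosine_distance(outputs, targets, axis=0)

print(loss_last, loss_batch)
```

In this toy setup only the embedding axis rewards predictions that point in the right direction, which is the behavior a loss on word vectors would presumably want.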

Separately, what optimizer and learning rate are you using? If the learning rate is too small, the weights may not move far enough in the right direction to make visible progress.

reese0106 answered Sep 21 '22 13:09
