How to predict a simple sequence using seq2seq from tensorflow?

I've recently started working with tensorflow, so I'm still struggling with the basics.

I wanted to create a simple seq2seq prediction model.

  • Input is a list of numbers between 0 and 1.
  • Output is the first number from the list, followed by the rest of the numbers each multiplied by the first.
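To make the task concrete, here is a minimal sketch of how one input/target pair could be generated. It uses only the Python standard library; `make_example` is a hypothetical helper name and is not taken from the linked gist.

```python
import random


def make_example(length, rng=random):
    """Return (inputs, targets) for one sequence of the toy task."""
    # Inputs are floats drawn from [0, 1).
    inputs = [rng.random() for _ in range(length)]
    first = inputs[0]
    # The target keeps the first number and scales every later number by it.
    targets = [first] + [first * x for x in inputs[1:]]
    return inputs, targets


inputs, targets = make_example(5, random.Random(0))
```

Passing a seeded `random.Random` keeps the example reproducible.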

I managed to evaluate model performance and optimize weights. What I've been struggling with is how to make predictions with the trained model.

 model_outputs, states = seq2seq.basic_rnn_seq2seq(encoder_inputs,
                                                  decoder_inputs,
                                                  rnn_cell.BasicLSTMCell(data_point_dim, state_is_tuple=True))

In order to generate model_outputs I need both input and output values for the model, which is fine for evaluation, but in prediction I only have input values. I'm guessing I need to do something with states, but I'm unsure how to transform them into a sequence of floats.

Full code is available here https://gist.github.com/anonymous/be405097927758acca158666854600a2

Daniel Lyam Montross asked Jun 27 '16 09:06



1 Answer

When you're training, the decoder input at each decoder timestep is the desired output from the previous timestep. When testing, you do not have the desired output, so the best you can do is sample an output from the model. This will be the input to the next timestep.

TL;DR: Feed in the decoder output at each timestep as the input for the next timestep.
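For the training side, a common way to build the decoder inputs is to shift the targets right by one step behind a start value. This is only a sketch, assuming plain Python lists of floats; `GO` and `make_decoder_inputs` are hypothetical names, not part of TensorFlow, and the start value 0.0 is an arbitrary choice.

```python
GO = 0.0  # assumed start value; any fixed placeholder would do


def make_decoder_inputs(targets, go=GO):
    """Shift targets right so step t sees the true output of step t-1."""
    return [go] + list(targets[:-1])


targets = [0.5, 0.1, 0.4]
decoder_inputs = make_decoder_inputs(targets)
```

The last target is never fed in, because nothing is predicted after it.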

Edit: some TF code

The basic_rnn_seq2seq function returns rnn_decoder(decoder_inputs, enc_states[-1], cell)

Let's look at rnn_decoder:

    def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, scope=None): ....

loop_function: if not None, this function will be applied to i-th output in order to generate i+1-th input, and decoder_inputs will be ignored, except for the first element ("GO" symbol). This can be used for decoding, but also for training to emulate http://arxiv.org/pdf/1506.03099v2.pdf.

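During decoding, you need to pass a loop_function. Note that it is a callable, not a flag: it is called as loop_function(prev, i) with the previous output and the step index, and returns the next decoder input. basic_rnn_seq2seq does not expose this argument, so you would have to call rnn_decoder yourself with the final encoder state. If your outputs have the same shape as your inputs, the function can probably just return prev.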
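To show what that control flow amounts to, here is a plain-Python sketch that mirrors the loop described in the rnn_decoder docstring. It is not TensorFlow code: `decode`, `toy_cell` and `feed_previous` are hypothetical stand-ins, and the "cell" is just a multiplication so the example runs without a trained model.

```python
def decode(decoder_inputs, initial_state, cell, loop_function=None):
    """Sketch of the decoder loop: optionally feed outputs back as inputs."""
    state = initial_state
    outputs = []
    prev = None
    for i, inp in enumerate(decoder_inputs):
        # With a loop_function, every input after the first is replaced
        # by a function of the previous output.
        if loop_function is not None and prev is not None:
            inp = loop_function(prev, i)
        output, state = cell(inp, state)
        outputs.append(output)
        prev = output
    return outputs, state


def toy_cell(inp, state):
    """Hypothetical stand-in for an RNN cell: output = input * state."""
    return inp * state, state


def feed_previous(prev, i):
    """Use the previous output unchanged as the next input."""
    return prev


# Only the first element (the "GO" value) is actually used when decoding.
placeholders = [1.0, 0.0, 0.0, 0.0]
predicted, _ = decode(placeholders, 0.5, toy_cell, loop_function=feed_previous)
teacher_forced, _ = decode(placeholders, 0.5, toy_cell)
```

With the loop function, each step consumes the previous output, so the placeholder zeros never reach the cell. Without it, the decoder inputs are used exactly as given.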

I recommend looking at the translate.py file in the TensorFlow seq2seq library to see how this is handled.

user4383691 answered Oct 11 '22 19:10
