
Tensorflow: How to pass output from previous time-step as input to next timestep

This is a duplicate of the question "How can I feed last output y(t-1) as input for generating y(t) in tensorflow RNN?"

I want to pass the output of an RNN at time step T as its input at time step T+1, i.e. input_RNN(T+1) = output_RNN(T). According to the documentation, both tf.nn.rnn and tf.nn.dynamic_rnn explicitly take the complete input for all time steps.

I checked the seq2seq example at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/seq2seq.py. It uses a loop and calls cell(input, state), where the cell can be an LSTM, a GRU, or any other RNN cell. I searched the documentation for the data types and shapes of the arguments to cell(), but I found only the constructor, which has the form cell(num_neurons). I would like to know the correct way of passing the output back in as input. I don't want to use other libraries or wrappers built on top of TensorFlow, such as Keras. Any suggestions?

Ruppesh Nalwaya asked Sep 24 '16 21:09



1 Answer

One way to do this is to write your own RNN cell, together with your own multi-RNN cell. That way you can store the output of the last RNN cell internally and simply read it back in the next time step. Check this blog post for more info. You can also add, e.g., encoders or decoders directly in the cell, so that you can process the data before feeding it to the cell or after retrieving it from the cell.
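
To make the idea concrete, here is a minimal pure-Python sketch of such a feedback wrapper. It does not use TensorFlow: `ToyCell` and `FeedbackCell` are hypothetical names invented for this illustration, and the only thing borrowed from TensorFlow is the calling convention cell(inputs, state) -> (output, new_state). In a real graph-mode cell the previous output would most likely have to travel inside the cell state rather than in a Python attribute, so treat this as a picture of the data flow, not as a drop-in implementation.

```python
import math


class ToyCell:
    """Hypothetical stand-in for an RNN cell: cell(inputs, state) -> (output, new_state)."""

    def __init__(self, w_in=0.5, w_state=0.8):
        self.w_in = w_in
        self.w_state = w_state

    def __call__(self, inputs, state):
        # One hidden unit per batch entry; inputs and state are lists of floats.
        new_state = [math.tanh(self.w_in * x + self.w_state * s)
                     for x, s in zip(inputs, state)]
        # As in a vanilla RNN, the output is the new hidden state.
        return new_state, new_state


class FeedbackCell:
    """Hypothetical wrapper that replaces the external input with the previous output."""

    def __init__(self, cell):
        self.cell = cell
        self.last_output = None  # nothing to feed back before the first step

    def __call__(self, inputs, state):
        # Use the real input only on the first step, then feed the output back.
        fed = inputs if self.last_output is None else self.last_output
        output, new_state = self.cell(fed, state)
        self.last_output = output
        return output, new_state


cell = FeedbackCell(ToyCell())
state = [0.0, 0.0]              # batch of two sequences
first_input = [1.0, -1.0]
ignored_input = [123.0, 456.0]  # passed in after step 0 but never used

outputs = []
for t in range(4):
    out, state = cell(first_input if t == 0 else ignored_input, state)
    outputs.append(out)
```

After the first step the wrapper ignores whatever is passed in as `inputs`, which is exactly the input_RNN(T+1) = output_RNN(T) behaviour asked for in the question.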

Another possibility is to use the function tf.nn.raw_rnn, which lets you control what happens before and after the calls to the RNN cell. The following code snippet shows how to use this function; credit goes to this article.

from tensorflow.python.ops.rnn import _transpose_batch_time
import tensorflow as tf


def sampling_rnn(self, cell, initial_state, input_, seq_lengths):

    # raw_rnn expects time-major inputs as TensorArrays
    max_time = ...  # the maximum number of time steps in the batch
    inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time, clear_after_read=False)
    # input_ is the batch-major input placeholder; transpose it to time-major before unstacking
    inputs_ta = inputs_ta.unstack(_transpose_batch_time(input_))
    input_dim = input_.get_shape()[-1].value  # the dimensionality of the input to each time step
    output_dim = ...  # the dimensionality of the model's output at each time step

    def loop_fn(time, cell_output, cell_state, loop_state):
        """
        Loop function that allows to control input to the rnn cell and manipulate cell outputs.
        :param time: current time step
        :param cell_output: output from previous time step or None if time == 0
        :param cell_state: cell state from previous time step
        :param loop_state: custom loop state to share information between different iterations of this loop fn
        :return: tuple consisting of
          elements_finished: tensor of size [batch_size] which is True for sequences that have reached their end,
            needed because of variable sequence size
          next_input: input to next time step
          next_cell_state: cell state forwarded to next time step
          emit_output: The first return argument of raw_rnn. This is not necessarily the output of the RNN cell,
            but could e.g. be the output of a dense layer attached to the rnn layer.
          next_loop_state: loop state forwarded to the next time step
        """
        if cell_output is None:
            # time == 0, used for initialization before first call to cell
            next_cell_state = initial_state
            # the emit_output in this case tells TF how future emits look
            emit_output = tf.zeros([output_dim])
        else:
            # t > 0, called right after call to cell, i.e. cell_output is the output from time t-1.
            # here you can do whatever you want with cell_output before assigning it to emit_output.
            # In this case, we don't do anything
            next_cell_state = cell_state
            emit_output = cell_output

        # check which elements are finished
        elements_finished = (time >= seq_lengths)
        finished = tf.reduce_all(elements_finished)

        # assemble cell input for upcoming time step
        current_output = emit_output if cell_output is not None else None
        input_original = inputs_ta.read(time)  # tensor of shape (None, input_dim)

        if current_output is None:
            # this is the initial step, i.e. there is no output from a previous time step; what we feed here
            # can highly depend on the data. In this case we just assign the actual input in the first time step.
            next_in = input_original
        else:
            # time > 0, so just use previous output as next input
            # here you could do fancier things, whatever you want to do before passing the data into the rnn cell
            # if you were to pass input_original here, then you would get the normal behaviour of dynamic_rnn
            next_in = current_output

        next_input = tf.cond(finished,
                             lambda: tf.zeros([self.batch_size, input_dim], dtype=tf.float32),  # copy through zeros
                             lambda: next_in)  # if not finished, feed the previous output as next input

        # set shape manually, otherwise it is not defined for the last dimension
        next_input.set_shape([None, input_dim])

        # loop state not used in this example
        next_loop_state = None
        return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)

    outputs_ta, last_state, _ = tf.nn.raw_rnn(cell, loop_fn)
    outputs = _transpose_batch_time(outputs_ta.stack())
    final_state = last_state

    return outputs, final_state
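
To see when loop_fn is called and with which arguments, it can help to imitate the control flow in plain Python. The sketch below follows the pseudo-code given in the tf.nn.raw_rnn documentation, but it is only an imitation: `toy_raw_rnn`, `toy_cell` and the list-based "tensors" are made up for this illustration, and the zero-input handling for finished batches from the snippet above is left out.

```python
def toy_cell(inputs, state):
    """Hypothetical cell: output and new state are both inputs + state."""
    new_state = [x + s for x, s in zip(inputs, state)]
    return new_state, new_state


def toy_raw_rnn(cell, loop_fn):
    """Pure-Python imitation of the control flow documented for tf.nn.raw_rnn."""
    time = 0
    # First call: cell_output is None, loop_fn supplies the initial input and state.
    finished, next_input, state, _, loop_state = loop_fn(time, None, None, None)
    emitted = []
    while not all(finished):
        output, cell_state = cell(next_input, state)
        # Later calls: loop_fn sees the cell output and decides the next input.
        next_finished, next_input, next_state, emit, loop_state = loop_fn(
            time + 1, output, cell_state, loop_state)
        # Sequences that were already finished keep their state and emit zeros.
        state = [s if f else ns for f, s, ns in zip(finished, state, next_state)]
        emitted.append([0.0 if f else e for f, e in zip(finished, emit)])
        finished = [f or nf for f, nf in zip(finished, next_finished)]
        time += 1
    return emitted, state, loop_state


seq_lengths = [3, 2]        # batch of two sequences with different lengths
first_inputs = [1.0, 10.0]  # only the first time step has a real input


def loop_fn(time, cell_output, cell_state, loop_state):
    elements_finished = [time >= n for n in seq_lengths]
    if cell_output is None:
        # time == 0: provide the initial input and the initial state
        next_input, next_state, emit = first_inputs, [0.0, 0.0], [0.0, 0.0]
    else:
        # time > 0: feed the previous output back in as the next input
        next_input, next_state, emit = cell_output, cell_state, cell_output
    return elements_finished, next_input, next_state, emit, None


outputs, final_state, _ = toy_raw_rnn(toy_cell, loop_fn)
# outputs is time-major: [[1.0, 10.0], [2.0, 20.0], [4.0, 0.0]]
# final_state is [4.0, 20.0]; the second sequence stopped after two steps
```

The second sequence has length 2, so its state is frozen and its emitted value is zero in the third step, while the first sequence keeps consuming its own output.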

As a side note: it is not clear whether relying on the model's own outputs during training is a good idea. Especially in the beginning, the outputs of the model can be quite bad, and each bad output becomes the next input, so your training might never converge or the model might not learn anything meaningful.
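
A toy calculation illustrates why this can be a problem. The numbers below are invented for this illustration and no RNN is involved: the "model" is just a poorly chosen multiplier, standing in for a network early in training. The sketch compares the per-step error when the true previous value is fed in with the error when the model's own previous output is fed in.

```python
TRUE_DECAY = 0.9   # the sequence to be modelled: x[t+1] = 0.9 * x[t]
MODEL_DECAY = 1.1  # a deliberately poor model, as one might have early in training

steps = 20
target = [TRUE_DECAY ** t for t in range(steps + 1)]

# Feeding the ground truth: every prediction starts from a correct value.
fed_truth_err = [abs(MODEL_DECAY * target[t] - target[t + 1]) for t in range(steps)]

# Feeding the model's own output: every prediction starts from the previous prediction.
own_output_err = []
prediction = target[0]
for t in range(steps):
    prediction = MODEL_DECAY * prediction
    own_output_err.append(abs(prediction - target[t + 1]))
```

With the true values as input the error shrinks along with the sequence, whereas with fed-back outputs it grows at every step, because each mistake is carried into the next input.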

kafman answered Oct 09 '22 10:10
