API Reference for RNN and Seq2Seq models in tensorflow

Where can I find the API reference that specifies the available functions in the RNN and Seq2Seq models?

On the GitHub page it was mentioned that rnn and seq2seq were moved to tf.nn.

asked Feb 08 '23 by Aravind Pilla

1 Answer

[NOTE: this answer is updated for r1.0 ... but explains legacy_seq2seq instead of tensorflow/tensorflow/contrib/seq2seq/]

The good news is that the seq2seq models provided in TensorFlow are pretty sophisticated, including embeddings, buckets, an attention mechanism, one-to-many multi-task models, etc.

The bad news is that there are many layers of complexity and abstraction in the Python code, and that the code itself is the best available "documentation" of the higher-level RNN and seq2seq "API" as far as I can tell... thankfully the code is well docstring'd.

Practically speaking, I think the examples and helper functions pointed to below are mainly useful for reference to understand coding patterns... and in most cases you'll need to re-implement what you need using the basic functions in the lower-level Python API.

Here is a breakdown of the RNN seq2seq code from the top down as of version r1.0:

models/tutorials/rnn/translate/translate.py

...provides main(), train(), decode() that work out-of-the-box to translate English to French... but you can adapt this code to other data sets

models/tutorials/rnn/translate/seq2seq_model.py

...class Seq2SeqModel() sets up a sophisticated RNN encoder-decoder with embeddings, buckets, and an attention mechanism... if you don't need embeddings, buckets, or attention, you'll need to implement a similar class yourself (a stripped-down sketch follows).
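For illustration only (not part of the original answer), here is a rough, hypothetical sketch of the kind of stripped-down class you might write without embeddings, buckets, or attention: a plain encoder loop whose final state seeds a decoder loop, built directly from a tf.contrib.rnn cell. The class name, parameters, and shapes are my assumptions.

import tensorflow as tf

class MinimalEncoderDecoder(object):
    # A bare-bones encoder-decoder: no embeddings, buckets, or attention.

    def __init__(self, cell, encoder_inputs, decoder_inputs):
        # encoder_inputs / decoder_inputs: lists of [batch, input_size] tensors,
        # one per time step (the same layout the legacy seq2seq helpers use).
        batch_size = tf.shape(encoder_inputs[0])[0]
        state = cell.zero_state(batch_size, tf.float32)

        # Encoder: run the cell over the inputs, keeping only the final state.
        with tf.variable_scope("encoder"):
            for t, inp in enumerate(encoder_inputs):
                if t > 0:
                    tf.get_variable_scope().reuse_variables()
                _, state = cell(inp, state)

        # Decoder: start from the encoder's final state and collect outputs.
        # Encoder and decoder get separate weights here because they live in
        # different variable scopes.
        self.outputs = []
        with tf.variable_scope("decoder"):
            for t, inp in enumerate(decoder_inputs):
                if t > 0:
                    tf.get_variable_scope().reuse_variables()
                output, state = cell(inp, state)
                self.outputs.append(output)
        self.final_state = state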

tensorflow/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py

...main entry point for seq2seq models via helper functions. See model_with_buckets(), embedding_attention_seq2seq(), embedding_attention_decoder(), attention_decoder(), sequence_loss(), etc. Examples such as one2many_rnn_seq2seq are included, and models without embeddings/attention, like basic_rnn_seq2seq, are provided as well. If you can jam your data into the tensors that these functions accept, this could be your best entry point to building your own model.
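To make that entry point concrete, here is a minimal sketch (mine, not from the answer) wiring basic_rnn_seq2seq() and sequence_loss() together; the toy vocabulary size, sequence length, placeholder shapes, and the choice of GRUCell are assumptions.

import tensorflow as tf

num_symbols = 10   # toy vocabulary size (assumed)
seq_len = 5        # fixed unrolled length for this sketch

# The legacy helpers take lists of per-time-step tensors.
encoder_inputs = [tf.placeholder(tf.float32, [None, num_symbols])
                  for _ in range(seq_len)]
decoder_inputs = [tf.placeholder(tf.float32, [None, num_symbols])
                  for _ in range(seq_len)]
targets = [tf.placeholder(tf.int32, [None]) for _ in range(seq_len)]
weights = [tf.ones_like(t, dtype=tf.float32) for t in targets]

# With num_units == num_symbols the raw RNN outputs can double as logits.
cell = tf.contrib.rnn.GRUCell(num_symbols)
outputs, state = tf.contrib.legacy_seq2seq.basic_rnn_seq2seq(
    encoder_inputs, decoder_inputs, cell)

loss = tf.contrib.legacy_seq2seq.sequence_loss(outputs, targets, weights)
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

In practice you would use the embedding/attention variants rather than one-hot style inputs, but the plumbing (lists of time-step tensors in, a list of logits and a loss out) is the same.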

tensorflow/tensorflow/contrib/rnn/python/ops/core_rnn.py

...provides wrappers for RNN networks like static_rnn() with some bells and whistles I usually don't need, so I just use code like this instead:

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope


def simple_rnn(cell, inputs, dtype, scope=None):
    # Unroll `cell` over `inputs`, a list of [batch, input_size] tensors.
    with variable_scope.variable_scope(scope or "simple_RNN") as varscope1:
        if varscope1.caching_device is None:
            varscope1.set_caching_device(lambda op: op.device)

        batch_size = array_ops.shape(inputs[0])[0]
        outputs = []
        state = cell.zero_state(batch_size, dtype)

        for time, input_t in enumerate(inputs):
            # Reuse the cell's variables after the first time step.
            if time > 0:
                variable_scope.get_variable_scope().reuse_variables()

            (output, state) = cell(input_t, state)
            outputs.append(output)

        return outputs, state
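And a quick usage sketch for simple_rnn() above; the GRUCell, shapes, and random feed data are assumptions for illustration.

import numpy as np
import tensorflow as tf

seq_len, input_size, num_units = 5, 8, 16
inputs = [tf.placeholder(tf.float32, [None, input_size]) for _ in range(seq_len)]
cell = tf.contrib.rnn.GRUCell(num_units)
outputs, state = simple_rnn(cell, inputs, dtype=tf.float32, scope="my_rnn")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {ph: np.random.rand(4, input_size).astype(np.float32) for ph in inputs}
    out_vals = sess.run(outputs, feed_dict=feed)
    print(len(out_vals), out_vals[0].shape)  # 5 time steps, each (4, 16)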
answered Mar 03 '23 by j314erre