
Build a graph that works with variable batch size using Tensorflow

I use tf.placeholder() ops to feed a variable batch size of input (2D tensors), and use the feed mechanism to supply different values for those tensors when I call run(). I got:

TypeError: 'Tensor' object is not iterable.

Following is my code:

with graph.as_default():
    train_index_input = tf.placeholder(tf.int32, shape=(None, window_size))
    train_embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_dimension], -1.0, 1.0))
    embedding_input = [tf.nn.embedding_lookup(train_embeddings, x) for x in train_index_input]
    ......
    ......

Since I cannot see the contents of the tensor "train_index_input" without running the graph, the error "'Tensor' object is not iterable" is raised for this line:

embedding_input = [tf.nn.embedding_lookup(train_embeddings, x) for x in train_index_input]

What I want to obtain is an embedding matrix "embedding_input" of shape [batch_size, embedding_dimension], where batch_size is not fixed. Do I have to define a new TensorFlow operation for embedding lookup on 2D tensors, or is there another way to do it? Thanks.

asked Nov 24 '15 by hgyp

1 Answer

You're trying to do a Python-level list comprehension (for x in train_index_input) over a TensorFlow placeholder. That won't work: at graph-construction time, Python has no idea what's inside a Tensor object, so it cannot iterate over it.

To accomplish a batch embedding lookup, what you can do is just flatten your batch:

train_indexes_flat = tf.reshape(train_index_input, [-1])

run it through embedding lookup:

looked_up_embeddings = tf.nn.embedding_lookup(train_embeddings, train_indexes_flat)

and then reshape it back into the right groups. Note that the lookup result has shape [batch_size * window_size, embedding_dimension], so the embedding dimension must be kept as the last axis when restoring the batch structure:

embedding_input = tf.reshape(looked_up_embeddings, [-1, window_size, embedding_dimension])

This gives a tensor of shape [batch_size, window_size, embedding_dimension], with batch_size still unfixed.
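TensorFlow isn't needed to see why the flatten → lookup → reshape pipeline works, since an embedding lookup on 1-D indices is just a row gather. Here is a minimal NumPy sketch of the same three steps; all sizes and index values are made up for illustration and are not from the original post:

```python
import numpy as np

# Hypothetical small sizes, chosen only for this illustration.
vocabulary_size = 10
embedding_dimension = 4
window_size = 3

rng = np.random.default_rng(0)
train_embeddings = rng.uniform(-1.0, 1.0,
                               size=(vocabulary_size, embedding_dimension))

# A batch of 2 windows of word indices; the batch size could be anything.
train_index_input = np.array([[1, 2, 3],
                              [4, 5, 6]])

# Step 1: flatten the [batch, window] index tensor to 1-D.
flat_indexes = train_index_input.reshape(-1)        # shape (6,)

# Step 2: row gather -- this is what embedding_lookup does for 1-D ids.
looked_up = train_embeddings[flat_indexes]          # shape (6, 4)

# Step 3: reshape back, keeping embedding_dimension as the last axis.
embedding_input = looked_up.reshape(-1, window_size, embedding_dimension)

print(embedding_input.shape)  # (2, 3, 4): [batch, window, embedding_dim]
```

For what it's worth, later TensorFlow releases let tf.nn.embedding_lookup take the 2-D [batch, window] index tensor directly and return the [batch, window, embedding_dimension] result, which makes the manual flatten/reshape unnecessary.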
answered Oct 22 '22 by dga