 

Tensorflow slicing based on variable

I've found that indexing is still an open issue in TensorFlow (#206), so I'm wondering what I could use as a workaround for now. I want to index/slice a row/column of a matrix based on a variable that changes for every training example.

What I've tried so far:

  1. Slicing based on a placeholder (doesn't work)

The following (working) code slices based on a fixed number.

import tensorflow as tf
import numpy as np

x = tf.placeholder("float")
y = tf.slice(x,[0],[1])

#initialize
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

#run
result = sess.run(y, feed_dict={x:[1,2,3,4,5]})
print(result)
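To make the semantics of tf.slice(x, begin, size) concrete, here is a NumPy sketch. The helper name np_slice is hypothetical and is not a TensorFlow API. It assumes the documented tf.slice behaviour: begin and size each have one entry per dimension of x, and a size of -1 means "everything from begin to the end of that dimension". The 2-D lines show how the same rule expresses "one row" or "one column" of a matrix.

```python
import numpy as np

def np_slice(x, begin, size):
    """Hypothetical NumPy mirror of tf.slice(x, begin, size).

    begin and size need one entry per dimension of x;
    size == -1 means "all remaining elements in that dimension".
    """
    x = np.asarray(x)
    if len(begin) != x.ndim or len(size) != x.ndim:
        raise ValueError("begin and size need one entry per dimension of x")
    index = []
    for b, s, dim in zip(begin, size, x.shape):
        end = dim if s == -1 else b + s
        index.append(slice(b, end))
    return x[tuple(index)]

# 1-D case from the question: one element starting at position 0
print(np_slice([1, 2, 3, 4, 5], [0], [1]))  # [1]

# 2-D case: a whole row, then a whole column, of a 3x4 matrix
m = np.arange(12).reshape(3, 4)
print(np_slice(m, [1, 0], [1, -1]))  # [[4 5 6 7]]
col = np_slice(m, [0, 2], [-1, 1])   # column 2 as a 3x1 array
```

Because begin is always a vector with one entry per dimension, slicing a 1-D x needs a one-element begin such as [0].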

However, it seems that I can't simply replace one of these fixed numbers with a tf.placeholder. The following code gives me the error "TypeError: List of Tensors when single Tensor expected."

import tensorflow as tf
import numpy as np

x = tf.placeholder("float")
i = tf.placeholder("int32")
y = tf.slice(x,[i],[1])  # raises TypeError: List of Tensors when single Tensor expected.

#initialize
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

#run
result = sess.run(y, feed_dict={x:[1,2,3,4,5],i:0})
print(result)

It looks as if the brackets around [i] are superfluous, but removing them doesn't help either. How can I use a placeholder/variable as an index?

  2. Slicing based on a Python variable (doesn't backprop/update properly)

I've also tried using a normal Python variable as the index. This doesn't raise an error, but the network doesn't learn anything during training. I suppose that because the changing variable is not properly registered, the graph is malformed and the updates don't work?
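One likely explanation (an assumption, since the training code isn't shown): when a plain Python int is passed to a TensorFlow op, it is converted into a constant at graph-construction time, so reassigning the Python variable later has no effect on the graph. The toy model below is plain Python, not TensorFlow code; build_slice_op and build_fed_slice_op are hypothetical names that only illustrate "value captured when the op is built" versus "value supplied on every run", the latter being the role feed_dict plays for a placeholder.

```python
def build_slice_op(begin):
    # The value of `begin` is captured when the "op" is created,
    # analogous to a Python int becoming a constant in the graph.
    def op(x):
        return x[begin:begin + 1]
    return op

def build_fed_slice_op():
    # A placeholder-style op instead receives the index on every call,
    # analogous to supplying it through feed_dict in sess.run().
    def op(x, begin):
        return x[begin:begin + 1]
    return op

index = 0
y = build_slice_op(index)        # "graph" built while index == 0
index = 3                        # later changes to the Python variable...
print(y([10, 20, 30, 40, 50]))   # ...are ignored: prints [10]

z = build_fed_slice_op()
print(z([10, 20, 30, 40, 50], index))  # uses the current index: prints [40]
```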

  3. Slicing via one-hot vector + multiplication (works, but is slow)

One workaround I found is to use a one-hot vector: I build the one-hot vector in NumPy, pass it in through a placeholder, and then do the slicing via matrix multiplication. This works, but it is quite slow.
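A NumPy sketch of the one-hot workaround described above, assuming a small 3x4 matrix and a hypothetical one_hot helper. Left-multiplying by a one-hot row vector selects a row, and right-multiplying by a one-hot column vector selects a column. The cost difference is visible in the shapes: the product performs rows*cols multiply-adds, whereas direct indexing only reads the elements of the chosen row.

```python
import numpy as np

def one_hot(index, depth):
    # Hypothetical helper: a float vector with a single 1.0 at `index`.
    v = np.zeros(depth, dtype=np.float32)
    v[index] = 1.0
    return v

m = np.arange(12, dtype=np.float32).reshape(3, 4)
row, col = 1, 2

# One-hot workaround: (3,) @ (3, 4) -> (4,), i.e. 3*4 multiply-adds
via_matmul = one_hot(row, m.shape[0]) @ m
# Same idea for a column: (3, 4) @ (4,) -> (3,)
col_via_matmul = m @ one_hot(col, m.shape[1])

# Direct indexing reads only the 4 elements of the chosen row
via_index = m[row]

print(via_matmul)                             # [4. 5. 6. 7.]
print(np.array_equal(via_matmul, via_index))  # True
```

Both forms give the same values; the matrix product simply does far more arithmetic than necessary as the matrix grows.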

Any ideas how to efficiently slice/index based on a variable?

asked Nov 30 '15 15:11 by Daniela


1 Answer

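Slicing based on a placeholder should work just fine. It looks like you are running into a type error caused by subtle issues of shapes and types: the begin argument of tf.slice must be a single 1-D integer tensor with one entry per dimension of x, whereas [i] is a Python list that contains a tensor. Where you have the following: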

x = tf.placeholder("float")
i = tf.placeholder("int32")
y = tf.slice(x,[i],[1])

...you should instead have:

x = tf.placeholder("float")
i = tf.placeholder("int32")
y = tf.slice(x,i,[1])

...and then you should feed i as [0] (a one-element list, matching the single dimension of x) rather than the scalar 0 in the call to sess.run().

To make this a little clearer, I would recommend rewriting the code as follows:

import tensorflow as tf
import numpy as np

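x = tf.placeholder(tf.float32, shape=[None])  # 1-D tensor of any length
i = tf.placeholder(tf.int32, shape=[1])       # begin: one entry per dimension of x
y = tf.slice(x, i, [1])                       # take 1 element starting at i[0]

# initialize (the graph has no variables, so this step is not strictly needed)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# run: i is fed as the list [0], not the scalar 0
result = sess.run(y, feed_dict={x: [1, 2, 3, 4, 5], i: [0]})
print(result)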

The additional shape arguments to the tf.placeholder op help ensure that the values you feed have the appropriate shapes: TensorFlow will raise an error if a fed value does not match the declared shape.
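The shape rule can be modelled in a few lines. This is a toy check with a hypothetical helper name (matches_declared_shape), not TensorFlow's own code. It assumes the placeholder convention that None matches any size in that dimension, and it shows why the scalar 0 is rejected for i = tf.placeholder(tf.int32, shape=[1]) while [0] is accepted.

```python
import numpy as np

def matches_declared_shape(value, declared_shape):
    """Hypothetical check: does `value` fit a placeholder declared with
    `declared_shape`? None matches any size in that dimension."""
    actual = np.asarray(value).shape
    if len(actual) != len(declared_shape):
        return False  # wrong rank, e.g. a scalar fed where a 1-D tensor is declared
    return all(d is None or d == a for d, a in zip(declared_shape, actual))

# x = tf.placeholder(tf.float32, shape=[None])
print(matches_declared_shape([1, 2, 3, 4, 5], [None]))  # True
# i = tf.placeholder(tf.int32, shape=[1])
print(matches_declared_shape([0], [1]))  # True: shape (1,)
print(matches_declared_shape(0, [1]))    # False: a scalar has shape ()
```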

answered Sep 18 '22 12:09 by mrry