
How can I use an index array in TensorFlow?

Given a matrix a with shape (5, 3) and an index array b with shape (5,), NumPy makes it easy to get the corresponding vector c, where c[i] = a[i, b[i]]:

c = a[np.arange(5), b]
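
For concreteness, here is a minimal sketch of what that NumPy expression computes. The values of a and b are made up for illustration; only the shapes come from the question.

```python
import numpy as np

# Hypothetical example data: a 5 x 3 matrix and one column index per row.
a = np.arange(15, dtype=np.float32).reshape(5, 3)
b = np.array([0, 2, 1, 1, 0])

# Integer-array indexing pairs the two index arrays element-wise,
# so c[i] == a[i, b[i]] for every row i.
c = a[np.arange(5), b]

# The same selection written as an explicit loop, for comparison.
c_loop = np.array([a[i, b[i]] for i in range(5)], dtype=np.float32)
```

Both c and c_loop hold one element per row of a, picked from the column that b names for that row.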

However, I cannot do the same thing in TensorFlow:

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, [5,])
# This line raises a ValueError.
c = a[tf.range(5), b]

Traceback (most recent call last):
  File "", line 1, in
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 513, in _SliceHelper
    name=name)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 671, in strided_slice
    shrink_axis_mask=shrink_axis_mask)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3688, in strided_slice
    shrink_axis_mask=shrink_axis_mask, name=name)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
    op_def=op_def)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2397, in create_op
    set_shapes_for_outputs(ret)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1757, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1707, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn
    debug_python_shape_fn, require_shape_fn)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 675, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice_14' (op: 'StridedSlice') with input shapes: [5,3], [2,5], [2,5], [2].

If this NumPy-style indexing does not work in TensorFlow, how can I produce the same result?

xxx222 asked Mar 06 '17 10:03



1 Answer

This feature is not currently implemented in TensorFlow. GitHub issue #4638 is tracking the implementation of NumPy-style "advanced" indexing. However, you can use the tf.gather_nd() operator to implement your program:

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, (5,))

row_indices = tf.range(5)

# `indices` is a 5 x 2 matrix of coordinates into `a`.
indices = tf.transpose([row_indices, b])

c = tf.gather_nd(a, indices)
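
To make the role of the 5 x 2 coordinate matrix clearer, here is a rough pure-Python sketch of the lookup that tf.gather_nd() performs in this case. The helper gather_nd_reference and the sample values are hypothetical, written only for illustration; this is not TensorFlow's implementation, and it only covers index rows that address a single element.

```python
def gather_nd_reference(params, indices):
    """Hypothetical stand-in for the lookup: one output per coordinate row."""
    result = []
    for coordinate in indices:
        value = params
        # Walk down one dimension per entry in the coordinate.
        for k in coordinate:
            value = value[k]
        result.append(value)
    return result


# Made-up sample data with the shapes from the question: (5, 3) and (5,).
a = [[0.0, 1.0, 2.0],
     [3.0, 4.0, 5.0],
     [6.0, 7.0, 8.0],
     [9.0, 10.0, 11.0],
     [12.0, 13.0, 14.0]]
b = [0, 2, 1, 1, 0]

row_indices = list(range(5))

# Pairing row i with column b[i] plays the role of
# tf.transpose([row_indices, b]): each row is one (row, column) coordinate.
indices = [list(pair) for pair in zip(row_indices, b)]

c = gather_nd_reference(a, indices)
```

Each row of indices is a full (row, column) coordinate into a, so the result has one element per row, matching the NumPy expression from the question.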
mrry answered Oct 17 '22 04:10
