
TensorFlow assign Tensor to Tensor with array indexing

I would like to do something like this piece of NumPy code, but in TensorFlow:

```python
import numpy as np

a = np.zeros([5, 2])
idx = np.random.randint(0, 2, (5,))
row_idx = np.arange(5)
a[row_idx, idx] = row_idx
```

That is, I want to index every row of a 2D tensor with a column index taken from another tensor, and then assign a tensor to the selected elements. I am absolutely clueless about how to achieve this.
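To make the intended semantics concrete, here is the NumPy version with a fixed idx (an assumption for reproducibility, replacing np.random.randint). It is shown next to the equivalent explicit loop that the vectorised assignment replaces:

```python
import numpy as np

# Fixed column indices (instead of np.random.randint) so the result is reproducible.
idx = np.array([0, 1, 1, 0, 1])
row_idx = np.arange(5)

# Vectorised form from the question: element (row_idx[i], idx[i]) gets row_idx[i].
a = np.zeros([5, 2])
a[row_idx, idx] = row_idx

# Equivalent explicit loop, shown only to spell out the semantics.
b = np.zeros([5, 2])
for r, c in zip(row_idx, idx):
    b[r, c] = r

print(a)
# [[0. 0.]
#  [0. 1.]
#  [0. 2.]
#  [3. 0.]
#  [0. 4.]]
```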

This is what I have so far in TensorFlow:

```python
import tensorflow as tf

a = tf.Variable(tf.zeros((5, 2)))
idx = tf.constant([0, 1, 1, 0, 1])
row_idx = tf.range(5)
indices = tf.transpose([row_idx, idx])  # (row, column) pairs, shape (5, 2)
r = tf.gather_nd(a, indices)
tf.assign(r, row_idx)  # This line does not work
```

When I run this, I get the following error on the last line:

```
AttributeError: 'Tensor' object has no attribute 'assign'
```

Is there a way around this? There must be a clean way to do this; I don't want to loop over the data with for loops and assign each element manually. I know that array indexing in TensorFlow is currently not as advanced as NumPy's, but this should still be possible somehow.
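For context on the error: tf.gather_nd copies the selected elements into a new tensor rather than returning a view of a, so there is nothing an assignment could write back into. The following is a minimal sketch that prints the index pairs and inspects the gathered result. It assumes TensorFlow 2.x with eager execution, not the TF 1.x API used above, and uses tf.stack, which builds the same (5, 2) index matrix as tf.transpose here:

```python
import tensorflow as tf

a = tf.Variable(tf.zeros((5, 2)))
idx = tf.constant([0, 1, 1, 0, 1])
row_idx = tf.range(5)

# Pair each row number with its column index: shape (5, 2).
indices = tf.stack([row_idx, idx], axis=1)
print(indices.numpy().tolist())  # [[0, 0], [1, 1], [2, 1], [3, 0], [4, 1]]

# gather_nd reads a[0, 0], a[1, 1], a[2, 1], a[3, 0], a[4, 1] into a new
# 1-D tensor; it is a copy, not a view of `a`, so it cannot be assigned to.
r = tf.gather_nd(a, indices)
print(r.shape)                     # (5,)
print(isinstance(r, tf.Variable))  # False
```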

Pethor asked Dec 05 '18 12:12


1 Answer

What you are trying to do is frequently done with tf.scatter_nd_update. However, that is usually not the right approach: you should not need a variable, only a new tensor derived from the original one with some values replaced. Unfortunately, there is no straightforward way to do this in general. If your original tensor really is all zeros, you can simply use tf.scatter_nd:

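```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

idx = tf.constant([0, 1, 1, 0, 1])
row_idx = tf.range(5)
# One (row, column) pair per row: shape (5, 2)
indices = tf.stack([row_idx, idx], axis=1)
# Scatter row_idx into an all-zeros (5, 2) tensor at those positions
a = tf.scatter_nd(indices, row_idx, (5, 2))
with tf.Session() as sess:
    print(sess.run(a))
# [[0 0]
#  [0 1]
#  [0 2]
#  [3 0]
#  [0 4]]
```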
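One caveat when using tf.scatter_nd: if the same index appears more than once, the corresponding updates are summed rather than overwritten. That does not matter here, because each row index occurs exactly once, but it can surprise you with other index patterns. A small sketch of the behaviour, assuming TensorFlow 2.x eager execution (hence no session) and using made-up indices:

```python
import tensorflow as tf

# Two updates target the same cell (0, 0); one targets (1, 1).
indices = tf.constant([[0, 0], [0, 0], [1, 1]])
updates = tf.constant([1, 2, 5])

# tf.scatter_nd accumulates (sums) updates for duplicate indices.
out = tf.scatter_nd(indices, updates, shape=(2, 2))
print(out.numpy().tolist())  # [[3, 0], [0, 5]]
```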

However, if the initial tensor is not all zeros, it is more complicated. One way to handle it is to do the same as above, then build a boolean mask marking the updated positions and select between the original and the updated values according to that mask:

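```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

a = tf.ones((5, 2), dtype=tf.int32)
idx = tf.constant([0, 1, 1, 0, 1])
row_idx = tf.range(5)
indices = tf.stack([row_idx, idx], axis=1)
# New values at the target positions, zeros elsewhere
a_update = tf.scatter_nd(indices, row_idx, (5, 2))
# True at the target positions, False elsewhere
update_mask = tf.scatter_nd(indices, tf.ones_like(row_idx, dtype=tf.bool), (5, 2))
# Take the update where the mask is set, the original value otherwise
a = tf.where(update_mask, a_update, a)
with tf.Session() as sess:
    print(sess.run(a))
# [[0 1]
#  [1 1]
#  [1 2]
#  [3 1]
#  [1 4]]
```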
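The examples above use the TensorFlow 1.x graph API (tf.Session). Assuming TensorFlow 2.x with eager execution, a sketch of both routes might look like this. tf.tensor_scatter_nd_update performs the masked replacement in a single call on an ordinary tensor. The tf.Variable.scatter_nd_update method covers the variable-based approach mentioned at the start of the answer, which is closest to the question's original attempt:

```python
import tensorflow as tf

idx = tf.constant([0, 1, 1, 0, 1])
row_idx = tf.range(5)
indices = tf.stack([row_idx, idx], axis=1)  # [[0, 0], [1, 1], [2, 1], [3, 0], [4, 1]]

# 1) Plain tensor: returns a new tensor; `a` itself is left unchanged.
a = tf.ones((5, 2), dtype=tf.int32)
a_updated = tf.tensor_scatter_nd_update(a, indices, row_idx)
print(a_updated.numpy())
# [[0 1]
#  [1 1]
#  [1 2]
#  [3 1]
#  [1 4]]

# 2) Variable: modified in place. The updates must match the
#    variable's dtype (float32), hence the cast of the int32 row numbers.
v = tf.Variable(tf.zeros((5, 2)))
v.scatter_nd_update(indices, tf.cast(row_idx, v.dtype))
print(v.numpy())
# [[0. 0.]
#  [0. 1.]
#  [0. 2.]
#  [3. 0.]
#  [0. 4.]]
```

The tensor version is usually preferable when you only need the result as a value. The variable version is only needed when the state must persist across steps, as with trainable or stateful variables.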
jdehesa answered Nov 04 '22 20:11