I would like to do the equivalent of this NumPy code in TensorFlow:
```python
import numpy as np

a = np.zeros([5, 2])
idx = np.random.randint(0, 2, (5,))  # one random column (0 or 1) per row
row_idx = np.arange(5)
a[row_idx, idx] = row_idx
```
That is, I want to index every row of a 2D tensor with another tensor (one column per row) and then assign a tensor to the selected elements. I have no idea how to achieve this.
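For reference, here is what that NumPy assignment produces when the random column index is replaced by the fixed `[0, 1, 1, 0, 1]` used in the TensorFlow attempt, so the TensorFlow results can be compared against it. The fixed index is an assumption made only to get a reproducible result.

```python
import numpy as np

a = np.zeros([5, 2])
# Fixed column index per row instead of np.random.randint, for reproducibility
idx = np.array([0, 1, 1, 0, 1])
row_idx = np.arange(5)

# Paired integer arrays select one element per row:
# (0, 0), (1, 1), (2, 1), (3, 0), (4, 1)
a[row_idx, idx] = row_idx
print(a)
```

Each selected element receives its own row number, and every other element stays zero.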
This is what I have so far in TensorFlow:
```python
import tensorflow as tf  # TensorFlow 1.x

a = tf.Variable(tf.zeros((5, 2)))
idx = tf.constant([0, 1, 1, 0, 1])
row_idx = tf.range(5)
indices = tf.transpose([row_idx, idx])  # [[0, 0], [1, 1], [2, 1], [3, 0], [4, 1]]
r = tf.gather_nd(a, indices)
tf.assign(r, row_idx)  # This line does not work
```
When I try to execute this, I get the following error in the last line:
```
AttributeError: 'Tensor' object has no attribute 'assign'
```
Is there a way around this? There must be a clean way to do this; I don't want to loop over the data with for loops and assign each element manually. I know that array indexing in TensorFlow is not yet as advanced as NumPy's, but this should still be possible somehow.
What you are trying to do is frequently done with `tf.scatter_nd_update`. However, that is usually not the right approach: you should not need a variable, just another tensor produced from the original one with some values replaced. Unfortunately, there is no straightforward way to do this in general. If your original tensor really is all zeros, you can simply use `tf.scatter_nd`:
```python
import tensorflow as tf  # TensorFlow 1.x (uses tf.Session)

idx = tf.constant([0, 1, 1, 0, 1])
row_idx = tf.range(5)
indices = tf.stack([row_idx, idx], axis=1)  # one (row, column) pair per row
a = tf.scatter_nd(indices, row_idx, (5, 2))
with tf.Session() as sess:
    print(sess.run(a))
# [[0 0]
#  [0 1]
#  [0 2]
#  [3 0]
#  [0 4]]
```
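In TensorFlow 2.x, where eager execution is the default and `tf.Session` is no longer part of the main API, the same `tf.scatter_nd` call can be evaluated directly. This is a minimal sketch assuming TensorFlow 2.x is installed:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

idx = tf.constant([0, 1, 1, 0, 1])
row_idx = tf.range(5)
# Pair each row index with its column index: [[0, 0], [1, 1], [2, 1], [3, 0], [4, 1]]
indices = tf.stack([row_idx, idx], axis=1)
# Scatter row_idx into a fresh all-zeros (5, 2) tensor at those positions
a = tf.scatter_nd(indices, row_idx, (5, 2))
print(a.numpy())
```

Note that `tf.scatter_nd` sums the updates at duplicate indices; here every row appears exactly once, so that does not matter.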
However, if the initial tensor is not all zeros, it is more complicated. One way to handle it is to do the same as above, then build a mask of the updated positions, and select between the original values and the updates according to that mask:
```python
import tensorflow as tf  # TensorFlow 1.x (uses tf.Session)

a = tf.ones((5, 2), dtype=tf.int32)
idx = tf.constant([0, 1, 1, 0, 1])
row_idx = tf.range(5)
indices = tf.stack([row_idx, idx], axis=1)
# Updates scattered into an all-zeros tensor
a_update = tf.scatter_nd(indices, row_idx, (5, 2))
# True exactly at the updated positions
update_mask = tf.scatter_nd(indices, tf.ones_like(row_idx, dtype=tf.bool), (5, 2))
# Take the update where the mask is True, the original value elsewhere
a = tf.where(update_mask, a_update, a)
with tf.Session() as sess:
    print(sess.run(a))
# [[0 1]
#  [1 1]
#  [1 2]
#  [3 1]
#  [1 4]]
```
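Later TensorFlow releases added a direct way to do this. In TensorFlow 2.x, `tf.tensor_scatter_nd_update(tensor, indices, updates)` returns a copy of an arbitrary tensor with the selected elements replaced, so the mask step is not needed; and if you really do want to modify a `tf.Variable` in place, as in the question, `tf.Variable.scatter_nd_update` does that. This is a sketch assuming TensorFlow 2.x, not the approach from the answer above:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

idx = tf.constant([0, 1, 1, 0, 1])
row_idx = tf.range(5)
indices = tf.stack([row_idx, idx], axis=1)

# Tensor case: returns a NEW tensor; the original `a` is left unchanged
a = tf.ones((5, 2), dtype=tf.int32)
a_new = tf.tensor_scatter_nd_update(a, indices, row_idx)
print(a_new.numpy())

# Variable case (as in the question): updates the variable in place.
# The updates must match the variable's dtype, so cast int32 -> float32.
v = tf.Variable(tf.zeros((5, 2)))
v.scatter_nd_update(indices, tf.cast(row_idx, tf.float32))
print(v.numpy())
```

The tensor version fits the usual TensorFlow style of producing new tensors from old ones; the variable version is only needed when the values must persist as model state.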