 

How to update a subset of 2D tensor in Tensorflow?

I want to set a single element of a 2D tensor to 0. Here data is a 2D tensor, and the value at its 2nd row, 2nd column should be replaced by 0. However, I am getting a type error. Can anyone help me with it?

TypeError: Input 'ref' of 'ScatterUpdate' Op requires l-value input

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
data2 = tf.reshape(data, [-1])
sparse_update = tf.scatter_update(data2, tf.constant([7]), tf.constant([0]))
#data = tf.reshape(data, [N,S])
init_op = tf.initialize_all_variables()

sess = tf.Session()
sess.run([init_op])
print "Values before:", sess.run([data])
#sess.run([updated_data_subset])
print "Values after:", sess.run([sparse_update])
asked Dec 11 '22 14:12 by user3480922



1 Answer

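tf.scatter_update only works on variables. In your code, data2 is the tensor returned by tf.reshape, not a variable, which is why the op complains that 'ref' requires an l-value input. Instead, rebuild the tensor from the slices around the element you want to change: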

Tensorflow version < 1.0: a = tf.concat(0, [a[:i], [updated_value], a[i+1:]])

Tensorflow version >= 1.0: a = tf.concat(axis=0, values=[a[:i], [updated_value], a[i+1:]])
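
The pattern above can be applied to the 2D case from the question by flattening, splicing, and reshaping back. This is only a minimal sketch, assuming TensorFlow 2.x with eager execution; the helper name replace_flat_element is hypothetical, and the flat index 7 is the one used in the question's code.

```python
import tensorflow as tf


def replace_flat_element(tensor, flat_index, new_value):
    """Return a copy of `tensor` with one element (row-major index) replaced.

    Hypothetical helper for illustration; it builds a new tensor and does
    not modify `tensor` in place.
    """
    flat = tf.reshape(tensor, [-1])
    # Wrap the new value in a length-1 tensor so it can be concatenated.
    middle = tf.reshape(tf.cast(new_value, flat.dtype), [1])
    updated = tf.concat(
        axis=0,
        values=[flat[:flat_index], middle, flat[flat_index + 1:]],
    )
    # Restore the original 2D shape.
    return tf.reshape(updated, tf.shape(tensor))


data = tf.constant([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0], [1, 2, 3, 4, 5]])
result = replace_flat_element(data, 7, 0)
print(result.numpy().tolist())
```

Flat index 7 corresponds to row 1, column 2 (zero-based), so the 8 in the second row is the element that becomes 0. Because the result is a new tensor, this works on the output of tf.reshape; if the value has to end up back in a tf.Variable, it can be written with the variable's assign method.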

answered Dec 13 '22 02:12 by Yaroslav Bulatov
