
How to Increment a Variable in TensorFlow?

While trying to use the Supervisor in TensorFlow, I was made aware that:

your training op is responsible for incrementing the global step value.

(Reference)

So how do you increment a variable in a graph in TensorFlow?

MattCochrane asked Nov 27 '22


1 Answer

Pretty simple solution:

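import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.assign, tf.Session)
global_step = tf.Variable(1, name='global_step', trainable=False, dtype=tf.int32)
increment_global_step_op = tf.assign(global_step, global_step+1)  # writes global_step + 1 back into global_step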
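As an alternative to the tf.assign call above, TensorFlow's assign_add builds a single op that adds a delta in place. The sketch below assumes the tensorflow.compat.v1 module (shipped with TF 1.14+ and 2.x) with v2 behavior disabled, so it runs in the same graph-and-session style as the answer. Note that the variable has to be initialized before the increment op can run.

```python
import tensorflow.compat.v1 as tf  # TF1-style API, also available in TF 2.x

tf.disable_v2_behavior()  # graph mode and sessions, matching the TF1 code above

global_step = tf.Variable(1, name='global_step', trainable=False, dtype=tf.int32)
# assign_add adds 1 to the variable in place and yields the updated value
increment_global_step_op = tf.assign_add(global_step, 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize before incrementing
    first = sess.run(increment_global_step_op)   # value after the first increment
    second = sess.run(increment_global_step_op)  # value after the second increment

print(first, second)
```

Each run of the op increments the stored value again, so first is 2 and second is 3.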

Then, whenever you want to increment it, run that op in the current tf.Session, sess (once the variables have been initialized):

step = sess.run(increment_global_step_op)

The value returned into step is the variable's value after the increment. Here global_step starts at 1, so step is 2.
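For comparison, here is a minimal sketch assuming TensorFlow 2.x, where eager execution is the default: there is no session and no separate op to run, because assign_add updates the variable immediately and the new value can be read straight from it.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

global_step = tf.Variable(1, name='global_step', trainable=False, dtype=tf.int32)

# In eager mode the update happens right away; no sess.run is needed.
global_step.assign_add(1)

step = int(global_step.numpy())  # read the incremented value
print(step)
```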

If you're using this for global_step like me, run it in the same sess.run call as your training op:

result = sess.run([out, increment_global_step_op], {x: [i]})
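If the training op comes from a TF1 optimizer, another option is to pass the step variable to minimize(), which makes the training op itself increment global_step by 1 each time it runs. The sketch below assumes tensorflow.compat.v1 with v2 behavior disabled; the placeholder x, weight w, and loss are a hypothetical toy model standing in for the question's out and x, not the author's actual network.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode and sessions, as in the TF1 code above

# Hypothetical toy model: learn w so that w * x is close to 2 * x.
x = tf.placeholder(tf.float32, shape=[None], name='x')
w = tf.Variable(0.0, name='w')
out = w * x
loss = tf.reduce_mean(tf.square(out - 2.0 * x))

# Non-trainable int64 step variable, created at 0 if it does not exist yet.
global_step = tf.train.get_or_create_global_step()

# With global_step passed in, each run of train_op also adds 1 to global_step,
# so no separate increment op is needed.
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(5):
        sess.run(train_op, {x: [float(i)]})
    final_step = sess.run(global_step)

print(final_step)
```

After five training steps, final_step is 5, even though the code never creates an explicit increment op.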
MattCochrane answered Dec 10 '22