I'm trying to understand the difference between tf.assign and the assignment operator (=). I have three snippets of code.
First, using tf.assign alone:
import tensorflow as tf
with tf.Graph().as_default():
    a = tf.Variable(1, name="a")
    assign_op = tf.assign(a, tf.add(a, 1))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(assign_op))
        print(a.eval())
        print(a.eval())
The output is as expected:
2
2
2
Second, using the assignment operator:
import tensorflow as tf
with tf.Graph().as_default():
    a = tf.Variable(1, name="a")
    a = a + 1
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(a))
        print(a.eval())
        print(a.eval())
The results are still 2, 2, 2.
Third, I use both tf.assign and the assignment operator:
import tensorflow as tf
with tf.Graph().as_default():
    a = tf.Variable(1, name="a")
    a = tf.assign(a, tf.add(a, 1))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(a))
        print(a.eval())
        print(a.eval())
Now, the output becomes 2, 3, 4.
My questions are:
In the 2nd snippet using (=), when I have sess.run(a), it seems I'm running an assign op. So does "a = a+1" internally create an assignment op like assign_op = tf.assign(a, a+1)? Is the op run by the session really just the assign_op? But when I run a.eval(), it doesn't continue to increment a, hence it appears eval is evaluating a "static" variable.
I'm not sure how to explain the 3rd snippet. Why the two evals increment a, but the two evals in the 2nd snippet doesn't?
Thanks.
In Python, = is the assignment operator: it binds the result of the expression on its right side to the name on its left side. == is the equality comparison operator: it checks whether the expressions on both sides are equal.
tf.assign() is documented in the TensorFlow API reference and is used frequently in projects. The op outputs a tensor holding the new value of "ref" (the variable being assigned to) once the assignment is done, which makes it easier to chain operations that need to use the updated value.
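To make the chaining point concrete, here is a toy, pure-Python sketch. The assign helper and the state dict are hypothetical stand-ins, not TensorFlow's API; the only behavior borrowed from tf.assign is that the call both updates the stored value and returns the new value, so the result can feed straight into the next operation.

```python
def assign(state, name, value):
    """Hypothetical toy analogue of tf.assign: store the new value, then return it."""
    state[name] = value
    return value


state = {"a": 1}
# The call returns the *new* value of "a" (2), so it can be used directly
# in the next computation without reading the state again.
total = assign(state, "a", state["a"] + 1) * 10
print(state["a"], total)   # 2 20
```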
First, the answer is not really precise. IMO, there's no distinction between Python objects and TF objects: they are all in-memory objects managed by Python's garbage collector.
If you change the second a to b and print the variables out:
In [2]: g = tf.Graph()
In [3]: with g.as_default():
   ...:     a = tf.Variable(1, name='a')
   ...:     b = a + 1
   ...:
In [4]: print(a)
<tf.Variable 'a:0' shape=() dtype=int32_ref>
In [5]: print(b)
Tensor("add:0", shape=(), dtype=int32)
In [6]: id(a)
Out[6]: 140253111576208
In [7]: id(b)
Out[7]: 140252306449616
a and b do not refer to the same object in memory.
Drawing the computation graph (or memory graph):
First line:
# a = tf.Variable(...
a -> var(a)
Second line:
# b = a + 1
b -> add - var(a)
     |
     \-- 1
Now if you change b = a + 1 back to a = a + 1, the name a after the assignment points to the output tensor of a tf.add op instead of to the variable a incremented by 1.
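The rebinding can be reproduced without TensorFlow. In this sketch Node is a hypothetical, minimal stand-in for a graph node (not TensorFlow's class); its + operator builds a new add node, much as + on a tf.Variable returns a new Tensor like the Tensor("add:0", ...) printed above. Writing a = a + 1 then only moves the Python name: the variable node still exists, unchanged, as an input of the new node, which matches the diagram.

```python
class Node:
    """Hypothetical, minimal stand-in for a graph node (not TensorFlow's class)."""

    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = tuple(inputs)

    def __add__(self, other):
        # `+` builds a *new* add node; nothing is computed or mutated here.
        return Node("add", [self, other])

    def __repr__(self):
        if not self.inputs:
            return self.op
        return f"{self.op}({', '.join(map(repr, self.inputs))})"


var_a = Node("var(a)")   # keep a separate handle on the variable node
a = var_a
a = a + 1                # rebinds the Python name `a` to the new add node

print(a)                 # add(var(a), 1)
print(a is var_a)        # False: `a` no longer refers to the variable
print(var_a)             # var(a): the variable node itself is unchanged
```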
When you call sess.run, you are fetching the result of that add op, which has no side effect on the original variable a. tf.assign, on the other hand, has the side effect of updating the variable's value held by the session.
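A toy, pure-Python model can show the side-effect difference. The classes Variable, Add, Assign and Session below are hypothetical stand-ins, not TensorFlow's implementation: the session holds the variable's value, running an add node only reads it, and running an assign node writes the new value back. Running the add node repeatedly mirrors the second snippet; running the assign node mirrors the first.

```python
class Variable:
    """Hypothetical stand-in for a graph variable; its value lives in the session."""
    def __init__(self, name):
        self.name = name


class Add:
    """Toy add op: computes input + constant with no side effects."""
    def __init__(self, x, c):
        self.x, self.c = x, c


class Assign:
    """Toy assign op: writes a new value into the session, then returns it."""
    def __init__(self, var, value):
        self.var, self.value = var, value


class Session:
    """Toy session holding variable values; run() evaluates a node recursively."""
    def __init__(self, init):
        self.state = dict(init)  # variable name -> current value

    def run(self, node):
        if isinstance(node, Variable):
            return self.state[node.name]
        if isinstance(node, Add):
            return self.run(node.x) + node.c         # read-only
        if isinstance(node, Assign):
            new_value = self.run(node.value)
            self.state[node.var.name] = new_value    # side effect
            return new_value
        raise TypeError(f"unknown node: {node!r}")


a = Variable("a")
add = Add(a, 1)                # like `a + 1`
assign = Assign(a, Add(a, 1))  # like tf.assign(a, tf.add(a, 1))

sess = Session({"a": 1})
print([sess.run(add) for _ in range(3)])  # [2, 2, 2]: the variable stays 1
print(sess.run(a))                        # 1
print(sess.run(assign))                   # 2: the variable is now 2
print(sess.run(a))                        # 2
```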
The main confusion here is that doing a = a + 1 will reassign the Python variable a to the resulting tensor of the addition operation a + 1. tf.assign, on the other hand, is an operation for setting the value of a TensorFlow variable.
a = tf.Variable(1, name="a")
a = a + 1
This is equivalent to:
a = tf.add(tf.Variable(1, name="a"), 1)
With that in mind:
In the 2nd snippet using (=), when I have sess.run(a), it seems I'm running an assign op. So does "a = a+1" internally create an assignment op like assign_op = tf.assign(a, a+1)? [...]
It might look that way, but it isn't true. As explained above, this only reassigns the Python variable. Without tf.assign or any other operation that changes the TensorFlow variable, the variable keeps the value 1. Each time a is evaluated, the program computes a + 1 => 1 + 1.
I'm not sure how to explain the 3rd snippet. Why the two evals increment a, but the two evals in the 2nd snippet doesn't?
That's because calling eval() on the assignment tensor in the third snippet also triggers the variable assignment (note that this isn't much different from calling session.run(a) with the current session).