
Difference between tf.assign and assignment operator (=)

Tags:

tensorflow

I'm trying to understand the difference between tf.assign and the assignment operator (=). I have three code snippets.

First, using a simple tf.assign:

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  assign_op = tf.assign(a, tf.add(a,1))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(assign_op))
    print(a.eval())
    print(a.eval())

The output is as expected:

2
2
2

Second, using the assignment operator:

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  a = a + 1
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))
    print(a.eval())
    print(a.eval())

The results are still 2, 2, 2.

Third, using both tf.assign and the assignment operator:

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  a = tf.assign(a, tf.add(a,1))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))
    print(a.eval())
    print(a.eval())

Now, the output becomes 2, 3, 4.

My questions are

  1. In the 2nd snippet using (=), when I have sess.run(a), it seems I'm running an assign op. So does "a = a+1" internally create an assignment op like assign_op = tf.assign(a, a+1)? Is the op run by the session really just the assign_op? But when I run a.eval(), it doesn't continue to increment a, hence it appears eval is evaluating a "static" variable.

  2. I'm not sure how to explain the 3rd snippet. Why do the two evals increment a, while the two evals in the 2nd snippet don't?

Thanks.

user8490020 asked Aug 20 '17 06:08


2 Answers

First, the answer is not really precise. IMO, there is no distinction between a Python object and a TF object. Both are in-memory objects managed by the Python GC.

If you rename the second a to b and print the variables:

In [2]: g = tf.Graph()

In [3]: with g.as_default():
   ...:     a = tf.Variable(1, name='a')
   ...:     b = a + 1
   ...:

In [4]: print(a)
<tf.Variable 'a:0' shape=() dtype=int32_ref>

In [5]: print(b)
Tensor("add:0", shape=(), dtype=int32)

In [6]: id(a)
Out[6]: 140253111576208

In [7]: id(b)
Out[7]: 140252306449616

a and b do not refer to the same object in memory.
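The same effect can be reproduced without TensorFlow. The sketch below is only an illustration: ToyVariable and ToyAdd are hypothetical stand-ins, not TensorFlow classes. It shows that Python's + on an object calls that object's __add__ method, which is free to return a brand-new object and leave the original untouched.

```python
# Hypothetical stand-ins for illustration only; these are not TensorFlow classes.
class ToyAdd(object):
    """A node that remembers its two inputs."""

    def __init__(self, left, right):
        self.left = left
        self.right = right


class ToyVariable(object):
    """A node that holds a mutable value."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        # `a + 1` builds a new node; it does not modify the variable.
        return ToyAdd(self, other)


a = ToyVariable(1)
b = a + 1

same_object = a is b          # b is a new ToyAdd node, not the variable
variable_value = a.value      # nothing was assigned, so the value is unchanged
b_points_at_a = b.left is a   # the add node merely references the variable
```

Writing a = a + 1 instead of b = a + 1 changes only which name is bound to the new node. The variable object itself still exists and still holds 1.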

Draw the computation graph (or memory graph).

First line:

# a = tf.Variable(...
a -> var(a)

Second line:

# b = a + 1
b -> add -- var(a)
      |
      \-- 1

Now, if you change b = a + 1 back to a = a + 1, the name a after the assignment points to the output of a tf.add operation instead of to the variable a incremented by 1.

When you call sess.run, you fetch the result of that add operation, with no side effect on the original variable a.

tf.assign, on the other hand, has the side effect of updating the state of the graph under the session.
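To make the difference between a pure add and an assign with a side effect concrete, here is a minimal TensorFlow-free sketch. All three classes and their run methods are hypothetical and only mimic the behaviour described above. They are not how TensorFlow is implemented.

```python
# Toy model for illustration only; none of these classes are TensorFlow APIs.
class ToyVariable(object):
    def __init__(self, value):
        self.value = value

    def run(self):
        return self.value


class ToyAdd(object):
    def __init__(self, node, constant):
        self.node = node
        self.constant = constant

    def run(self):
        # Pure computation: reads the input, writes nothing.
        return self.node.run() + self.constant


class ToyAssign(object):
    def __init__(self, variable, node):
        self.variable = variable
        self.node = node

    def run(self):
        # Side effect: stores the computed value in the variable.
        self.variable.value = self.node.run()
        return self.variable.value


var = ToyVariable(1)
add = ToyAdd(var, 1)

# Running the add node repeatedly always computes 1 + 1.
add_results = [add.run(), add.run(), add.run()]
value_after_add = var.value

# Running the assign node once changes what the variable holds.
assign = ToyAssign(var, add)
assign_result = assign.run()
value_after_assign = var.value
```

In this toy model, add_results is [2, 2, 2] and the variable still holds 1 afterwards. Only the assign node changes the stored value, to 2.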

Izana answered Oct 21 '22 04:10


The main confusion here is that doing a = a + 1 will reassign the Python variable a to the resulting tensor of the addition operation a + 1. tf.assign, on the other hand, is an operation for setting the value of a TensorFlow variable.

a = tf.Variable(1, name="a")
a = a + 1

This is equivalent to:

a = tf.add(tf.Variable(1, name="a"), 1)

With that in mind:

In the 2nd snippet using (=), when I have sess.run(a), it seems I'm running an assign op. So does "a = a+1" internally create an assignment op like assign_op = tf.assign(a, a+1)? [...]

It might look that way, but it is not the case. As explained above, this only reassigns the Python variable. Without tf.assign or any other operation that changes the TensorFlow variable, the variable keeps the value 1. Each time a is evaluated, the program calculates a + 1 => 1 + 1 again.

I'm not sure how to explain the 3rd snippet. Why do the two evals increment a, while the two evals in the 2nd snippet don't?

That's because calling eval() on the assignment tensor in the third snippet also triggers the variable assignment (note that this isn't much different from doing session.run(a) with the current session).
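The contrast between the first and third snippets can be sketched with plain Python functions standing in for graph nodes. This is an analogy under stated assumptions, not TensorFlow code: state, read_a and increment_a are hypothetical names introduced only for this illustration.

```python
# Toy model: plain Python functions stand in for graph nodes. Not TensorFlow code.
state = {"a": 1}  # plays the role of the variable's stored value


def read_a():
    # Plays the role of evaluating the variable itself.
    return state["a"]


def increment_a():
    # Plays the role of tf.assign(a, tf.add(a, 1)): compute, store, return.
    state["a"] = state["a"] + 1
    return state["a"]


# First snippet: the assign op has its own name, and `a` still reads the variable.
first = [increment_a(), read_a(), read_a()]

# Third snippet: the name `a` now refers to the assign op itself,
# so every evaluation of `a` runs the assignment again.
state["a"] = 1
a = increment_a
third = [a(), a(), a()]
```

Here first comes out as [2, 2, 2] and third as [2, 3, 4], matching the outputs reported in the question.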

E_net4 stands with Ukraine answered Oct 21 '22 06:10