The API discusses Graph Collections, which, judging from the code, are general-purpose key/data storage. What is the purpose of those collections?
Graphs are data structures that contain a set of tf.Operation objects, which represent units of computation, and tf.Tensor objects, which represent the units of data that flow between operations. They are defined in a tf.Graph context.
TensorFlow uses a dataflow graph to represent computation in terms of the dependencies between individual operations. This leads to a low-level programming model in which you first define the dataflow graph, then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
A dataflow graph represents a computation: the nodes are units of computation, and the edges are the data consumed or produced by the computation. In the context of tf.Graph, every API call defines a tf.Operation (node) that can have multiple input and output tf.Tensor objects (edges).
It's simple: a graph defines the computation. It doesn't compute anything and it doesn't hold any values; it just defines the operations that you specified in your code. A session allows you to execute graphs or parts of graphs.
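To make the split between defining a graph and running it concrete, here is a minimal pure-Python sketch. It is a toy model, not TensorFlow's implementation: the Node, Graph and Session classes are hypothetical stand-ins. Adding nodes computes nothing; Session.run evaluates only the nodes that the fetched node depends on, which mirrors running "parts of the graph".

```python
from typing import Callable, Dict, List


class Node:
    """A toy operation: a function plus the nodes whose outputs it consumes."""

    def __init__(self, name: str, fn: Callable[..., float], inputs: List["Node"]):
        self.name = name
        self.fn = fn
        self.inputs = inputs


class Graph:
    """Holds node definitions only; adding a node does not compute anything."""

    def __init__(self):
        self.nodes: Dict[str, Node] = {}

    def add(self, name: str, fn: Callable[..., float], *inputs: Node) -> Node:
        node = Node(name, fn, list(inputs))
        self.nodes[name] = node
        return node


class Session:
    """Evaluates only the subgraph that a fetched node depends on."""

    def __init__(self, graph: Graph):
        self.graph = graph
        self.executed: List[str] = []  # records which nodes actually ran

    def run(self, fetch: Node) -> float:
        cache: Dict[str, float] = {}

        def evaluate(node: Node) -> float:
            if node.name not in cache:
                # Evaluate dependencies first (the incoming edges), then this node.
                args = [evaluate(inp) for inp in node.inputs]
                cache[node.name] = node.fn(*args)
                self.executed.append(node.name)
            return cache[node.name]

        return evaluate(fetch)


g = Graph()
a = g.add("a", lambda: 3.0)
b = g.add("b", lambda: 4.0)
total = g.add("add", lambda x, y: x + y, a, b)
unused = g.add("neg", lambda x: -x, b)  # defined, but `total` does not need it

sess = Session(g)
print(sess.run(total))  # 7.0
print(sess.executed)    # ['a', 'b', 'add']
```

The "neg" node is part of the graph but never executes, because the fetched node does not depend on it.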
Remember that under the hood, TensorFlow is a system for specifying and then executing computational dataflow graphs. The graph collections help keep track of the constructed graph and how it must be executed. For example, when you create certain kinds of ops, such as tf.train.batch_join, the code that adds the op will also add some queue runners to the QUEUE_RUNNERS graph collection. Later, when you call start_queue_runners(), by default it will look at the QUEUE_RUNNERS collection to know which runners to start.
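The queue-runner mechanism boils down to a registry: code that builds an op also drops a helper object into a named collection, and a later call reads that collection instead of asking the user to pass the helpers explicitly. The sketch below is a hypothetical pure-Python model of that idea; ToyGraph, ToyQueueRunner, toy_batch_join and toy_start_queue_runners are illustrative names, not TensorFlow APIs.

```python
from collections import defaultdict
from typing import Any, Dict, List

# Hypothetical key, mirroring the role of the QUEUE_RUNNERS collection.
QUEUE_RUNNERS = "queue_runners"


class ToyGraph:
    """Graph-level key -> list storage, loosely modelled on graph collections."""

    def __init__(self):
        self._collections: Dict[str, List[Any]] = defaultdict(list)

    def add_to_collection(self, name: str, value: Any) -> None:
        self._collections[name].append(value)

    def get_collection(self, name: str) -> List[Any]:
        # Return a copy so callers cannot mutate the stored list by accident.
        return list(self._collections.get(name, []))


class ToyQueueRunner:
    def __init__(self, name: str):
        self.name = name
        self.started = False

    def start(self) -> None:
        self.started = True


def toy_batch_join(graph: ToyGraph, name: str) -> str:
    """Hypothetical op builder: besides building its op, it registers a runner."""
    graph.add_to_collection(QUEUE_RUNNERS, ToyQueueRunner(name))
    return f"{name}:0"  # stand-in for the op's output tensor


def toy_start_queue_runners(graph: ToyGraph) -> List[ToyQueueRunner]:
    """Starts every registered runner; the caller never lists them explicitly."""
    runners = graph.get_collection(QUEUE_RUNNERS)
    for runner in runners:
        runner.start()
    return runners


g = ToyGraph()
toy_batch_join(g, "batch_a")
toy_batch_join(g, "batch_b")
started = toy_start_queue_runners(g)
print([r.name for r in started if r.started])  # ['batch_a', 'batch_b']
```

The design point is decoupling: the code that creates the runners and the code that starts them only have to agree on a collection key.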
I think there are at least two benefits for me so far:
For instance:
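import tensorflow.compat.v1 as tf  # TF1-style graph API; also available under TensorFlow 2.x
tf.disable_v2_behavior()

# A variable created with an explicit `collections` list goes only into those
# collections (not GLOBAL_VARIABLES), so it is initialized from the collection.
w = tf.Variable([1, 2, 3], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)
w2 = tf.Variable([11, 22, 32], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)
weight_init_op = tf.variables_initializer(tf.get_collection_ref(tf.GraphKeys.WEIGHTS))
sess = tf.InteractiveSession()
weight_init_op.run()

# Register one scaling op per weight in the UPDATE_OPS collection.
for vari in tf.get_collection_ref(tf.GraphKeys.WEIGHTS):
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, vari.assign(0.2 * vari))

# Later, fetch every registered update op from the collection and run it.
weight_update_ops = tf.get_collection_ref(tf.GraphKeys.UPDATE_OPS)
for op in weight_update_ops:
    print(op.eval())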
The output:
[0.2 0.4 0.6]
[2.2 4.4 6.4]
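The example above uses tf.get_collection_ref, which returns the stored list itself, whereas tf.get_collection returns a copy. Here is a minimal pure-Python model of that difference, assuming a plain dict of lists as the backing store. The ToyGraph class is hypothetical and is not TensorFlow code.

```python
from typing import Any, Dict, List


class ToyGraph:
    def __init__(self):
        self._collections: Dict[str, List[Any]] = {}

    def add_to_collection(self, name: str, value: Any) -> None:
        self._collections.setdefault(name, []).append(value)

    def get_collection(self, name: str) -> List[Any]:
        # Like get_collection: a fresh copy, so edits do not leak back.
        return list(self._collections.get(name, []))

    def get_collection_ref(self, name: str) -> List[Any]:
        # Like get_collection_ref: the stored list itself (created if missing).
        return self._collections.setdefault(name, [])


g = ToyGraph()
g.add_to_collection("weights", "w")
g.add_to_collection("weights", "w2")

copy = g.get_collection("weights")
copy.append("not_registered")  # only changes the copy

ref = g.get_collection_ref("weights")
ref.append("w3")  # changes the collection itself

print(g.get_collection("weights"))  # ['w', 'w2', 'w3']
```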