How can I list all Tensorflow variables/constants/placeholders a node depends on?
Example 1 (addition of constants):
import tensorflow as tf
a = tf.constant(1, name = 'a')
b = tf.constant(3, name = 'b')
c = tf.constant(9, name = 'c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')
sess = tf.Session()
print(sess.run([d, e]))
I would like to have a function list_dependencies() such that:
list_dependencies(d) returns ['a', 'b']
list_dependencies(e) returns ['a', 'b', 'c']
Example 2 (matrix multiplication between a placeholder and a weight matrix, followed by the addition of a bias vector):
tf.set_random_seed(1)
input_size = 5
output_size = 3
input = tf.placeholder(tf.float32, shape=[1, input_size], name='input')
W = tf.get_variable(
    "W",
    shape=[input_size, output_size],
    initializer=tf.contrib.layers.xavier_initializer())
b = tf.get_variable(
    "b",
    shape=[output_size],
    initializer=tf.constant_initializer(2))
output = tf.matmul(input, W, name="output")
output_bias = tf.nn.xw_plus_b(input, W, b, name="output_bias")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run([output,output_bias], feed_dict={input: [[2]*input_size]}))
I would like to have a function list_dependencies() such that:
list_dependencies(output) returns ['W', 'input']
list_dependencies(output_bias) returns ['W', 'b', 'input']
Here are utilities I use for this (from https://github.com/yaroslavvb/stuff/blob/master/linearize/linearize.py)
# computation flows from parents to children
def parents(op):
    return set(input.op for input in op.inputs)

def children(op):
    return set(op for out in op.outputs for op in out.consumers())

def get_graph():
    """Creates dictionary {node: {child1, child2, ..},..} for current
    TensorFlow graph. Result is compatible with networkx/toposort"""
    ops = tf.get_default_graph().get_operations()
    return {op: children(op) for op in ops}

def print_tf_graph(graph):
    """Prints tensorflow graph in dictionary form."""
    for node in graph:
        for child in graph[node]:
            print("%s -> %s" % (node.name, child.name))
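Since get_graph() returns a plain {node: set_of_children} dictionary, it can be handed directly to networkx (or the toposort package), as the docstring notes. As a small usage sketch, assuming networkx is installed, this prints the ops of the current graph in topological order (parents before children):

import networkx as nx

# get_graph() as defined above; edges go from each op to its children
order = nx.topological_sort(nx.DiGraph(get_graph()))
print([op.name for op in order])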
These functions work on ops. To get the op that produces a tensor t, use t.op. To get the tensors produced by an op op, use op.outputs.
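Building on parents(), one way to get the list_dependencies() behaviour asked for in the question is to walk the graph upward from a tensor's op and collect the names of the ops that have no inputs (constants, placeholders, variables). Here is a minimal sketch; list_dependencies() is not a TensorFlow function, and the upward traversal and sorted output are just one possible implementation:

def list_dependencies(tensor):
    """Names of the source ops (constants, placeholders, variables)
    that the given tensor depends on."""
    seen = set()
    sources = set()
    stack = [tensor.op]                # start from the op that produces the tensor
    while stack:
        op = stack.pop()
        if op in seen:
            continue
        seen.add(op)
        direct_parents = parents(op)   # reuses parents() defined above
        if not direct_parents:
            sources.add(op.name)       # no inputs: this op is a graph source
        else:
            stack.extend(direct_parents)
    return sorted(sources)

With the graph from example 1 this should give list_dependencies(d) == ['a', 'b'] and list_dependencies(e) == ['a', 'b', 'c']; for example 2 the variables are reached through their read ops (e.g. W/read), so list_dependencies(output_bias) should give ['W', 'b', 'input'].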
Yaroslav Bulatov's answer is great, I'll just add one plotting function that uses Yaroslav's get_graph() and children() functions:
import matplotlib.pyplot as plt
import networkx as nx
def plot_graph(G):
    '''Plot a DAG using NetworkX'''
    def mapping(node):
        return node.name
    G = nx.DiGraph(G)
    nx.relabel_nodes(G, mapping, copy=False)
    nx.draw(G, cmap=plt.get_cmap('jet'), with_labels=True)
    plt.show()
plot_graph(get_graph())
Plotting example 1 from the question:
import matplotlib.pyplot as plt
import networkx as nx
import tensorflow as tf
def children(op):
    return set(op for out in op.outputs for op in out.consumers())

def get_graph():
    """Creates dictionary {node: {child1, child2, ..},..} for current
    TensorFlow graph. Result is compatible with networkx/toposort"""
    print('get_graph')
    ops = tf.get_default_graph().get_operations()
    return {op: children(op) for op in ops}

def plot_graph(G):
    '''Plot a DAG using NetworkX'''
    def mapping(node):
        return node.name
    G = nx.DiGraph(G)
    nx.relabel_nodes(G, mapping, copy=False)
    nx.draw(G, cmap=plt.get_cmap('jet'), with_labels=True)
    plt.show()
a = tf.constant(1, name = 'a')
b = tf.constant(3, name = 'b')
c = tf.constant(9, name = 'c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')
sess = tf.Session()
print(sess.run([d, e]))
plot_graph(get_graph())
Output (NetworkX plot of the example 1 graph):
Plotting example 2 from the question:
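The script for example 2 is analogous; below is a minimal self-contained sketch, assuming the same children(), get_graph() and plot_graph() helpers and reusing the graph-building code from the question:

import matplotlib.pyplot as plt
import networkx as nx
import tensorflow as tf

def children(op):
    return set(op for out in op.outputs for op in out.consumers())

def get_graph():
    """Creates dictionary {node: {child1, child2, ..},..} for current
    TensorFlow graph. Result is compatible with networkx/toposort"""
    ops = tf.get_default_graph().get_operations()
    return {op: children(op) for op in ops}

def plot_graph(G):
    '''Plot a DAG using NetworkX'''
    def mapping(node):
        return node.name
    G = nx.DiGraph(G)
    nx.relabel_nodes(G, mapping, copy=False)
    nx.draw(G, cmap=plt.get_cmap('jet'), with_labels=True)
    plt.show()

# Graph from example 2 in the question
tf.set_random_seed(1)
input_size = 5
output_size = 3
input = tf.placeholder(tf.float32, shape=[1, input_size], name='input')
W = tf.get_variable(
    "W",
    shape=[input_size, output_size],
    initializer=tf.contrib.layers.xavier_initializer())
b = tf.get_variable(
    "b",
    shape=[output_size],
    initializer=tf.constant_initializer(2))
output = tf.matmul(input, W, name="output")
output_bias = tf.nn.xw_plus_b(input, W, b, name="output_bias")

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run([output, output_bias], feed_dict={input: [[2] * input_size]}))
plot_graph(get_graph())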
If you use Microsoft Windows, you may run into this issue: Python Error (ValueError: _getfullpathname: embedded null character), in which case you need to patch matplotlib as the link explains.