How can I list all TensorFlow variables a node depends on?

How can I list all TensorFlow variables/constants/placeholders a node depends on?

Example 1 (addition of constants):

import tensorflow as tf

a = tf.constant(1, name = 'a')
b = tf.constant(3, name = 'b')
c = tf.constant(9, name = 'c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')

sess = tf.Session()
print(sess.run([d, e]))

I would like to have a function list_dependencies() such that:

  • list_dependencies(d) returns ['a', 'b']
  • list_dependencies(e) returns ['a', 'b', 'c']

Example 2 (matrix multiplication between a placeholder and a weight matrix, followed by the addition of a bias vector):

tf.set_random_seed(1)
input_size  = 5
output_size = 3
input       = tf.placeholder(tf.float32, shape=[1, input_size], name='input')
W           = tf.get_variable(
                "W",
                shape=[input_size, output_size],
                initializer=tf.contrib.layers.xavier_initializer())
b           = tf.get_variable(
                "b",
                shape=[output_size],
                initializer=tf.constant_initializer(2))
output      = tf.matmul(input, W, name="output")
output_bias = tf.nn.xw_plus_b(input, W, b, name="output_bias")

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run([output,output_bias], feed_dict={input: [[2]*input_size]}))

I would like to have a function list_dependencies() such that:

  • list_dependencies(output) returns ['W', 'input']
  • list_dependencies(output_bias) returns ['W', 'b', 'input']
asked Feb 15 '17 at 18:02 by Franck Dernoncourt


2 Answers

Here are utilities I use for this (from https://github.com/yaroslavvb/stuff/blob/master/linearize/linearize.py)

import tensorflow as tf

# computation flows from parents to children

def parents(op):
  """Ops that produce the inputs of `op`."""
  return set(inp.op for inp in op.inputs)

def children(op):
  """Ops that consume at least one output of `op`."""
  return set(consumer for out in op.outputs for consumer in out.consumers())

def get_graph():
  """Creates dictionary {node: {child1, child2, ..},..} for current
  TensorFlow graph. Result is compatible with networkx/toposort"""

  ops = tf.get_default_graph().get_operations()
  return {op: children(op) for op in ops}


def print_tf_graph(graph):
  """Prints tensorflow graph in dictionary form."""
  for node in graph:
    for child in graph[node]:
      print("%s -> %s" % (node.name, child.name))

These functions work on ops (tf.Operation objects), not tensors. To get the op that produces a tensor t, use t.op; to get the tensors an op produces, use op.outputs.

answered Sep 29 '22 at 11:09 by Yaroslav Bulatov


Yaroslav Bulatov's answer is great; I'll just add a plotting function that uses his get_graph() and children() functions:

import matplotlib.pyplot as plt
import networkx as nx
def plot_graph(G):
    '''Plot a DAG using NetworkX'''        
    def mapping(node):
        return node.name
    G = nx.DiGraph(G)
    nx.relabel_nodes(G, mapping, copy=False)
    nx.draw(G, cmap = plt.get_cmap('jet'), with_labels = True)
    plt.show()

plot_graph(get_graph())

Plotting Example 1 from the question:

import matplotlib.pyplot as plt
import networkx as nx
import tensorflow as tf

def children(op):
  return set(op for out in op.outputs for op in out.consumers())

def get_graph():
  """Creates dictionary {node: {child1, child2, ..},..} for current
  TensorFlow graph. Result is compatible with networkx/toposort"""
  ops = tf.get_default_graph().get_operations()
  return {op: children(op) for op in ops}

def plot_graph(G):
    '''Plot a DAG using NetworkX'''        
    def mapping(node):
        return node.name
    G = nx.DiGraph(G)
    nx.relabel_nodes(G, mapping, copy=False)
    nx.draw(G, cmap = plt.get_cmap('jet'), with_labels = True)
    plt.show()

a = tf.constant(1, name = 'a')
b = tf.constant(3, name = 'b')
c = tf.constant(9, name = 'c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')

sess = tf.Session()
print(sess.run([d, e]))
plot_graph(get_graph())

Output:

[Image: NetworkX plot of the Example 1 graph]

Plotting Example 2 from the question:

[Image: NetworkX plot of the Example 2 graph]

If you use Microsoft Windows, you may run into the issue "Python Error (ValueError: _getfullpathname: embedded null character)", in which case you need to patch matplotlib as explained in the discussion of that issue.
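If you only need a textual evaluation order rather than a picture (or want to avoid matplotlib altogether), the {op: children} dictionary returned by get_graph() can also be ordered with the standard library's graphlib, in line with the "compatible with networkx/toposort" remark in its docstring. This is a minimal sketch; execution_order is a hypothetical helper name, and the example uses op names as plain strings instead of real tf.Operation objects.

```python
from graphlib import TopologicalSorter  # standard library, Python 3.9+


def execution_order(graph):
    """Return the nodes of a {node: set_of_children} DAG parents-first.

    TopologicalSorter expects {node: predecessors}. Handing it the
    children sets instead makes every child come out before its parents,
    so the resulting order is reversed. Raises graphlib.CycleError if
    the graph contains a cycle.
    """
    children_first = list(TopologicalSorter(graph).static_order())
    return children_first[::-1]


# Example 1 from the question, written as {node: {children}} with op names
graph = {"a": {"d"}, "b": {"d"}, "c": {"e"}, "d": {"e"}, "e": set()}
print(execution_order(graph))  # one valid order; ties may come out differently
```

With TensorFlow loaded you would call it as [op.name for op in execution_order(get_graph())]. Graphs built with TF1 while loops contain back edges, so graphlib may raise CycleError on them.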

answered Sep 29 '22 at 11:09 by Franck Dernoncourt