How can I list all TensorFlow variables a node depends on?

How can I list all TensorFlow variables/constants/placeholders a node depends on?

Example 1 (addition of constants):

import tensorflow as tf

a = tf.constant(1, name='a')
b = tf.constant(3, name='b')
c = tf.constant(9, name='c')
d = tf.add(a, b, name='d')  # depends on a, b
e = tf.add(d, c, name='e')  # depends on a, b (through d) and c

# Use the session as a context manager so it is closed after use.
with tf.Session() as sess:
    print(sess.run([d, e]))

I would like to have a function `list_dependencies()` such that:

  • list_dependencies(d) returns ['a', 'b']
  • list_dependencies(e) returns ['a', 'b', 'c']

Example 2 (matrix multiplication between a placeholder and a weight matrix, followed by the addition of a bias vector):

input_size  = 5
output_size = 3
input       = tf.placeholder(tf.float32, shape=[1, input_size], name='input')
# The op names 'W' and 'b' are what list_dependencies() should report.
W           = tf.get_variable(
                'W', shape=[input_size, output_size])
b           = tf.get_variable(
                'b', shape=[output_size])
output      = tf.matmul(input, W, name="output")
output_bias = tf.nn.xw_plus_b(input, W, b, name="output_bias")

with tf.Session() as sess:
    # Variables must be initialized before the graph can read them.
    sess.run(tf.global_variables_initializer())
    print(sess.run([output, output_bias], feed_dict={input: [[2]*input_size]}))

I would like to have a function `list_dependencies()` such that:

  • list_dependencies(output) returns ['W', 'input']
  • list_dependencies(output_bias) returns ['W', 'b', 'input']
2 Answers

Here are utilities I use for this (from https://github.com/yaroslavvb/stuff/blob/master/linearize/linearize.py)

# computation flows from parents to children

def parents(op):
  """Return the set of ops that produce the input tensors of ``op``.

  Each input tensor knows the op that produced it via ``tensor.op``;
  duplicates (an op feeding several inputs) collapse into one set entry.
  """
  # Named ``tensor`` rather than ``input`` to avoid shadowing the builtin.
  return {tensor.op for tensor in op.inputs}

def children(op):
  """Return the set of ops that consume any output tensor of ``op``."""
  # Use a distinct name for the consumer so it does not shadow the
  # ``op`` parameter inside the comprehension.
  return {consumer for out in op.outputs for consumer in out.consumers()}

def get_graph(graph=None):
  """Creates dictionary {node: {child1, child2, ..},..} for a TensorFlow graph.

  Args:
    graph: the tf.Graph to inspect. Defaults to the current default graph,
      so calling get_graph() with no argument behaves as before.

  Returns:
    A dict mapping every op in the graph to the set of ops that consume its
    outputs. Result is compatible with networkx/toposort.
  """
  if graph is None:
    graph = tf.get_default_graph()
  return {op: children(op) for op in graph.get_operations()}

def print_tf_graph(graph):
  """Print each edge of ``graph`` on its own line as ``parent -> child``.

  ``graph`` is a dict mapping an op to an iterable of its child ops (the
  shape produced by get_graph()); nodes are printed by their ``.name``.
  """
  for parent, kids in graph.items():
    for kid in kids:
      print("%s -> %s" % (parent.name, kid.name))

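These functions work on ops. To get the op that produces a tensor `t`, use `t.op`; to get the tensors produced by an op `op`, use `op.outputs`. To implement `list_dependencies(t)`, walk `parents()` transitively starting from `t.op` and keep the ops that have no inputs of their own (such as constants, placeholders and variables).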

Yaroslav Bulatov's answer is great; I'll just add a plotting function that uses Yaroslav's `get_graph()` and `children()` functions:

import matplotlib.pyplot as plt
import networkx as nx
def plot_graph(G):
    """Draw the op graph ``G`` with NetworkX, labelling nodes by op name.

    ``G`` is a ``{op: {child_op, ...}}`` dictionary such as the one built by
    get_graph(); every node must have a ``.name`` attribute.
    """
    dag = nx.DiGraph(G)
    # Replace op objects by their names in place so the labels are readable.
    nx.relabel_nodes(dag, lambda op: op.name, copy=False)
    nx.draw(dag, cmap=plt.get_cmap('jet'), with_labels=True)


Plotting example 1 from the question:

import matplotlib.pyplot as plt
import networkx as nx
import tensorflow as tf

def children(op):
  """Return the set of ops that consume any output tensor of ``op``."""
  # Use a distinct name for the consumer so it does not shadow the
  # ``op`` parameter inside the comprehension.
  return {consumer for out in op.outputs for consumer in out.consumers()}

def get_graph(graph=None):
  """Creates dictionary {node: {child1, child2, ..},..} for a TensorFlow graph.

  Args:
    graph: the tf.Graph to inspect. Defaults to the current default graph,
      so calling get_graph() with no argument behaves as before.

  Returns:
    A dict mapping every op in the graph to the set of ops that consume its
    outputs. Result is compatible with networkx/toposort.
  """
  if graph is None:
    graph = tf.get_default_graph()
  return {op: children(op) for op in graph.get_operations()}

def plot_graph(G):
    """Draw the op graph ``G`` with NetworkX, labelling nodes by op name.

    ``G`` is a ``{op: {child_op, ...}}`` dictionary such as the one built by
    get_graph(); every node must have a ``.name`` attribute.
    """
    dag = nx.DiGraph(G)
    # Replace op objects by their names in place so the labels are readable.
    nx.relabel_nodes(dag, lambda op: op.name, copy=False)
    nx.draw(dag, cmap=plt.get_cmap('jet'), with_labels=True)

a = tf.constant(1, name='a')
b = tf.constant(3, name='b')
c = tf.constant(9, name='c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')

# Use the session as a context manager so it is closed after use.
with tf.Session() as sess:
    print(sess.run([d, e]))

# Draw the dependency graph of the ops created above and display it.
plot_graph(get_graph())
plt.show()


Plotting example 2 from the question:

If you use Microsoft Windows, you may run into the error described in "Python Error (ValueError: _getfullpathname: embedded null character)"; in that case, you need to patch matplotlib as explained there.

