
Deleting all but a few nodes in TensorFlow graph

My use-case of TensorFlow requires me to build a new computation graph for each instance that needs to be processed. This ends up blowing up the memory requirements.

Apart from a few tf.Variables that are model parameters, I'd like to delete all other nodes. Other people with similar problems have found tf.reset_default_graph() to be useful, but this would get rid of the model parameters that I need to persist.

What can I use to delete all but these nodes?

Edit: The instance-specific computations actually just mean I am adding a lot of new operations. I believe these operations are the reason behind the memory issues.

UPDATE: See the recently released TensorFlow Fold (https://github.com/tensorflow/fold), which allows dynamic construction of computation graphs.

asked May 08 '16 17:05 by navari

People also ask

Which function is used for removing node from graph in Python?

remove_node(n). Removes the node n and all adjacent edges.

What is TF Reset_default_graph ()?

tf.reset_default_graph() Defined in tensorflow/python/framework/ops.py . See the guide: Building Graphs > Utility functions. Clears the default graph stack and resets the global default graph.

What is default graph in TensorFlow?

A graph is like a TODO list. You may use more than one graph (created with tf.Graph()) in the same process, but one is the default. Note you will have to use a different session for each graph, but each graph can be used in multiple sessions. Moreover, a session allows executing whole graphs or parts of graphs.


1 Answer

The tf.Graph data structure is designed to be append-only. It is therefore not possible to remove or modify existing nodes. Usually this is not a problem, as only the necessary subgraph is processed when running a session.
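As a quick illustration of the append-only behaviour (a sketch using the TF 1.x graph API; on TF 2.x the same calls live under the `tf.compat.v1` shim):

```python
import tensorflow.compat.v1 as tf  # plain `import tensorflow as tf` on TF 1.x
tf.disable_eager_execution()

g = tf.Graph()
with g.as_default():
    a = tf.constant(1.0)
    n_before = len(g.get_operations())
    b = a + 1.0  # every new op is appended to the graph's op list
    n_after = len(g.get_operations())

# n_after > n_before, and there is no API to shrink the list again;
# this is why building per-instance ops makes memory grow without bound
```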

What you can try is to copy the Variables of your graph into a new graph and delete the old one. To achieve this, just run:

old_graph = tf.get_default_graph()  # save the old graph for later iteration
new_graph = tf.Graph()              # create an empty graph
with new_graph.as_default():        # makes the new graph the default
    pass                            # build the replacement nodes here

If you want to iterate over all nodes in the old graph use:

for node in old_graph.get_operations():
    if node.type == 'Variable':  # 'VariableV2' on newer TensorFlow versions
        pass  # read the value of the variable and copy it into the new graph

Alternatively you can use:

for node in old_graph.get_collection('trainable_variables'):
    # iterates over all trainable Variables
    pass  # read the value and create a new variable in the new graph
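Putting the pieces together, here is a minimal end-to-end sketch of the copy (assuming a toy model with a single variable `w`; on TF 2.x the `tf.compat.v1` shim provides the same session-based API):

```python
import tensorflow.compat.v1 as tf  # plain `import tensorflow as tf` on TF 1.x
tf.disable_eager_execution()

# Build a toy "old" graph with one trainable variable
old_graph = tf.Graph()
with old_graph.as_default():
    w = tf.get_variable('w', initializer=[1.0, 2.0])
    with tf.Session(graph=old_graph) as sess:
        sess.run(tf.global_variables_initializer())
        # Read the current values while the old graph is still alive
        values = {v.op.name: sess.run(v)
                  for v in old_graph.get_collection('trainable_variables')}

# Re-create the variables in a fresh graph; the old graph (and all the
# instance-specific ops attached to it) can now be garbage-collected
new_graph = tf.Graph()
with new_graph.as_default():
    new_vars = {name: tf.get_variable(name, initializer=value)
                for name, value in values.items()}
```

Dropping every reference to `old_graph` (and its sessions) lets Python reclaim the memory held by the instance-specific ops.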

Also have a look at python/framework/ops.py : 1759 for more ways of manipulating nodes in a graph.

However, before you mess around with tf.Graph I would strongly recommend considering whether this is really required. Usually one can generalize the computation and use shared variables to build a graph, so that each instance you want to process is a subgraph of this graph.
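The shared-graph approach can be sketched like this (a hypothetical toy model `y = w * x`, where each instance becomes a different feed into one fixed graph rather than new ops):

```python
import tensorflow.compat.v1 as tf  # plain `import tensorflow as tf` on TF 1.x
tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None], name='x')  # per-instance input
    w = tf.get_variable('w', initializer=2.0)               # shared model parameter
    y = w * x                                               # shared computation
    init = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init)
    # Each instance is just a different feed; the graph never grows
    results = [sess.run(y, feed_dict={x: instance})
               for instance in ([1.0, 2.0], [3.0])]
```

Because no new ops are created per instance, memory stays constant no matter how many instances are processed.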

answered Oct 17 '22 02:10 by MarvMind