
Tensorflow: Creating a graph in a class and running it outside

I am having a hard time understanding how graphs work in TensorFlow and how to access them. My intuition is that the lines under `with self.graph.as_default():` form the graph as a single entity. Hence, I decided to create a class that builds a graph when instantiated and exposes a method that runs it, as follows:

class Graph(object):

    #To build the graph when instantiated
    def __init__(self, parameters ):
        self.graph = tf.Graph()
        with self.graph.as_default():
             ...
             prediction = ... 
             cost       = ...
             optimizer  = ...
             ...
    # To launch the graph
    def launchG(self, inputs):
        with tf.Session(graph=self.graph) as sess:
             ...
             sess.run(optimizer, feed_dict)
             loss = sess.run(cost, feed_dict)
             ...
        return variables

The next step is to create a main file that assembles the parameters to pass to the class, builds the graph, and then runs it:

#Main file
...
parameters_dict = { 'n_input': 28, 'learnRate': 0.001, ... }

#Building graph
G = Graph(parameters_dict)
P = G.launchG(Input)
...

This seems very elegant to me, but it doesn't quite work (obviously). Indeed, it seems like the launchG function does not have access to the nodes defined in the graph, which gives me errors such as:

---> 26 sess.run(optimizer, feed_dict)

NameError: name 'optimizer' is not defined

Perhaps my Python (and TensorFlow) understanding is too limited, but I was under the impression that once the graph G is created, running a session with this graph as an argument should give access to the nodes in it, without requiring me to expose them explicitly.

Any enlightenment?

asked Jun 12 '16 by PhABC



1 Answer

The nodes prediction, cost, and optimizer are local variables created inside the method __init__; they go out of scope when __init__ returns, so they cannot be accessed from the method launchG.

The easiest fix would be to declare them as attributes of your class Graph:
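This is plain Python scoping, independent of TensorFlow. A minimal sketch (with hypothetical names) of the difference between a local name and an instance attribute:

```python
class Model(object):
    def __init__(self):
        # Local name: exists only while __init__ is executing.
        local_node = "local"
        # Attribute: stored on the instance, visible to every other method.
        self.attr_node = "attribute"

    def run(self):
        return self.attr_node      # fine: looked up on self
        # return local_node        # would raise NameError: name 'local_node' is not defined

m = Model()
print(m.run())  # prints "attribute"
```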

class Graph(object):

    #To build the graph when instantiated
    def __init__(self, parameters ):
        self.graph = tf.Graph()
        with self.graph.as_default():
             ...
             self.prediction = ... 
             self.cost       = ...
             self.optimizer  = ...
             ...
    # To launch the graph
    def launchG(self, inputs):
        with tf.Session(graph=self.graph) as sess:
             ...
             sess.run(self.optimizer, feed_dict)
             loss = sess.run(self.cost, feed_dict)
             ...
        return variables

You can also retrieve the nodes of the graph using their exact name with graph.get_tensor_by_name and graph.get_operation_by_name.

answered Oct 19 '22 by Olivier Moindrot