Open TensorFlow graph from file

I'm learning TensorFlow and I don't understand how to open and use a graph (of type tf.Graph) that I previously saved to a file. Something like this:

import tensorflow as tf

my_graph = tf.Graph()

with my_graph.as_default():
    x = tf.Variable(0)
    b = tf.constant(-5)
    k = tf.constant(2)

    y = k*x + b

tf.train.write_graph(my_graph, '.', 'graph.pbtxt')

f = open('graph.pbtxt', "r")

# Do something with "f" to get my saved graph and use it below in
# tf.Session(graph=...) instead of dots

with tf.Session(graph=...) as sess:
    tf.initialize_all_variables().run()

    y1 = sess.run(y, feed_dict={x: 5})
    y2 = sess.run(y, feed_dict={x: 10})
    print(y1, y2)
Sergey asked Oct 28 '16 13:10


1 Answer

You have to read the file contents, parse them into a GraphDef, and then import it. The GraphDef is imported into the current default graph, so you may want to wrap the import in a with graph.as_default(): block.

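import tensorflow as tf
from tensorflow.core.framework import graph_pb2 as gpb
from google.protobuf import text_format as pbtf

gdef = gpb.GraphDef()

# Read the text-format protobuf written by tf.train.write_graph
with open('graph.pbtxt', 'r') as fh:
    graph_str = fh.read()

# Parse the text into the GraphDef message
pbtf.Parse(graph_str, gdef)

# Import the parsed nodes into the current default graph
tf.import_graph_def(gdef)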
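For reference, here is a minimal end-to-end sketch of the same idea, assuming TensorFlow 2.x (or a 1.x release recent enough to provide the tf.compat.v1 API). It swaps the question's tf.Variable for a tf.compat.v1.placeholder, because a placeholder is the usual way to feed values into a graph. The helper names build_graph, load_graph and evaluate are hypothetical. After importing, the tensors are looked up by name: the Python objects x and y from the original graph do not belong to the newly loaded graph, so they cannot be passed to a session that runs it.

```python
import os
import tempfile

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2


def build_graph():
    """Build y = k*x + b; x is a placeholder (an assumption) so it can be fed."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.compat.v1.placeholder(tf.int32, shape=[], name="x")
        b = tf.constant(-5, name="b")
        k = tf.constant(2, name="k")
        tf.add(k * x, b, name="y")
    return graph


def load_graph(path):
    """Read a text-format GraphDef and import it into a fresh tf.Graph."""
    graph_def = graph_pb2.GraphDef()
    with open(path, "r") as fh:
        text_format.Parse(fh.read(), graph_def)
    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original node names instead of prefixing "import/"
        tf.compat.v1.import_graph_def(graph_def, name="")
    return graph


def evaluate(path, values):
    """Load the saved graph and compute y for each fed value of x."""
    loaded = load_graph(path)
    # Look tensors up by name, since the original Python objects are not reusable
    x = loaded.get_tensor_by_name("x:0")
    y = loaded.get_tensor_by_name("y:0")
    with tf.compat.v1.Session(graph=loaded) as sess:
        return [int(sess.run(y, feed_dict={x: v})) for v in values]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # Write the graph as text, like tf.train.write_graph in the question
        tf.io.write_graph(build_graph(), tmp, "graph.pbtxt", as_text=True)
        print(evaluate(os.path.join(tmp, "graph.pbtxt"), [5, 10]))
```

With k = 2 and b = -5 this computes 2*5 - 5 and 2*10 - 5, printing [5, 15]. If you call import_graph_def without name="", every node gets the default "import/" prefix, so you would look up "import/y:0" instead of "y:0".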
dm0_ answered Sep 20 '22 16:09