 

TensorFlow - import meta graph and use variables from it

I'm training a classification CNN using TensorFlow v0.12, and then I want to create labels for new data using the trained model.

At the end of the training script, I added these lines of code:

saver = tf.train.Saver()
save_path = saver.save(sess,'/home/path/to/model/model.ckpt')

After training completed, the folder contains these files:

1. checkpoint
2. model.ckpt.data-00000-of-00001
3. model.ckpt.index
4. model.ckpt.meta
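The four files split the checkpoint into two parts: model.ckpt.meta holds the graph structure (a MetaGraphDef), model.ckpt.index and model.ckpt.data-00000-of-00001 hold the variable values, and checkpoint is a small text file recording the most recent checkpoint prefix. Both import_meta_graph and restore work from the common prefix model.ckpt rather than from any one data file. The helper below is a hypothetical pure-Python sketch (checkpoint_prefixes and meta_graph_path are illustrative names, not TensorFlow API) that recovers that prefix from a directory listing, assuming the file layout shown above.

```python
import os


def checkpoint_prefixes(directory):
    """Return the checkpoint prefixes found in `directory`, sorted.

    Hypothetical helper. A prefix is the path you pass to saver.restore():
    the file name without the .index / .data-* / .meta suffix. Every saved
    checkpoint in this layout has exactly one .index file, so we key on it.
    """
    prefixes = []
    for name in os.listdir(directory):
        if name.endswith(".index"):
            prefixes.append(os.path.join(directory, name[: -len(".index")]))
    return sorted(prefixes)


def meta_graph_path(prefix):
    """Path of the graph file that tf.train.import_meta_graph() expects."""
    return prefix + ".meta"
```

With the four files above in /home/path/to/model, checkpoint_prefixes('/home/path/to/model') returns ['/home/path/to/model/model.ckpt'], which is the value to pass to saver.restore(), while meta_graph_path() of that prefix gives the argument for import_meta_graph().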

Then I tried to restore the model using the .meta file. Following this tutorial, I added the following line to my classification code:

saver=tf.train.import_meta_graph(savepath+'model.ckpt.meta') #line1

and then:

saver.restore(sess, save_path=savepath+'model.ckpt') #line2

Before that change, I had to rebuild the graph and then write (instead of line1):

saver = tf.train.Saver()

But when I deleted the graph-building code and used line1 to restore the graph instead, I got an error: my code uses a variable from the graph, and Python doesn't recognize it:

predictions = sess.run(y_conv, feed_dict={x: patches,keep_prob: 1.0})

Python doesn't recognize the name y_conv. Is there a way to restore the variables using the meta graph? If not, what is this restore good for, if I can't use variables from the original graph?

I know this question isn't very clear, but the problem was hard for me to put into words. Sorry about that...

Thanks for answering, appreciate your help! Roi.

asked Feb 06 '17 16:02 by roishik




2 Answers

It is possible, don't worry. Assuming you don't want to modify the graph any more, do something like this:

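import tensorflow as tf

saver = tf.train.import_meta_graph('model/export/{}.meta'.format(model_name))
saver.restore(sess, 'model/export/{}'.format(model_name))
graph = tf.get_default_graph()
# x and keep_prob are undefined too; look them up by name (assumes name='x' / name='keep_prob' were set)
y_conv = graph.get_operation_by_name('y_conv').outputs[0]
x = graph.get_tensor_by_name('x:0')
keep_prob = graph.get_tensor_by_name('keep_prob:0')
predictions = sess.run(y_conv, feed_dict={x: patches, keep_prob: 1.0})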
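Below is a minimal, self-contained sketch of this name-based approach, written against tf.compat.v1 so that it runs on TensorFlow 1.15 and 2.x. The tiny graph (a matmul plus dropout) is a hypothetical stand-in for the CNN, and the names "x", "keep_prob" and "y_conv" are assumptions: they only resolve because the training part passes them as name= arguments. Note that x and keep_prob have to be looked up as well, because restoring the meta graph brings back graph nodes, not Python variables.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs and sessions

# --- Training script (sketch): name every tensor you will need later. ---
x = tf.placeholder(tf.float32, [None, 2], name="x")
keep_prob = tf.placeholder(tf.float32, shape=[], name="keep_prob")
W = tf.Variable([[1.0, 0.0], [0.0, 1.0]], name="W")
b = tf.Variable([0.5, -0.5], name="b")
h = tf.nn.dropout(tf.matmul(x, W) + b, rate=1.0 - keep_prob)
y_conv = tf.identity(h, name="y_conv")  # hypothetical stand-in for the CNN output

prefix = os.path.join(tempfile.mkdtemp(), "model")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, prefix)  # writes model.meta, model.index, model.data-*

# --- Classification script: no graph-building code at all. ---
tf.reset_default_graph()
with tf.Session() as sess:
    saver = tf.train.import_meta_graph(prefix + ".meta")  # rebuilds the graph
    saver.restore(sess, prefix)                           # loads the values of W and b
    graph = tf.get_default_graph()
    # The Python names from the training script are gone; fetch tensors by name.
    x = graph.get_tensor_by_name("x:0")
    keep_prob = graph.get_tensor_by_name("keep_prob:0")
    y_conv = graph.get_operation_by_name("y_conv").outputs[0]
    patches = np.array([[1.0, 2.0]], dtype=np.float32)
    predictions = sess.run(y_conv, feed_dict={x: patches, keep_prob: 1.0})
```

With keep_prob fed as 1.0 the dropout keeps every unit, so predictions equals patches times W plus b.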

A preferred way, however, is to add the ops to collections when you build the graph and then retrieve them from there. So when you define the graph, you would add the line:

tf.add_to_collection("y_conv", y_conv)

And then after you import the metagraph and restore it, you would call:

y_conv = tf.get_collection("y_conv")[0]
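A sketch of the collection-based variant, under the same assumptions as before (tf.compat.v1 API, a hypothetical toy graph in place of the CNN). The collection keys "x" and "y_conv" are arbitrary strings chosen for this example. Because collections are saved inside the .meta file, the tensors can be found again without knowing their op names, so this also works for tensors created without name=.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# --- Training script (sketch): register the tensors you will need later. ---
x = tf.placeholder(tf.float32, [None, 2])  # deliberately left unnamed
w = tf.Variable([[2.0], [3.0]])
y_conv = tf.matmul(x, w)                   # hypothetical stand-in for the CNN output
tf.add_to_collection("x", x)
tf.add_to_collection("y_conv", y_conv)

prefix = os.path.join(tempfile.mkdtemp(), "model")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, prefix)  # collections are stored in model.meta

# --- Classification script: retrieve by collection key, not by op name. ---
tf.reset_default_graph()
with tf.Session() as sess:
    saver = tf.train.import_meta_graph(prefix + ".meta")
    saver.restore(sess, prefix)
    x = tf.get_collection("x")[0]
    y_conv = tf.get_collection("y_conv")[0]
    predictions = sess.run(y_conv, feed_dict={x: np.array([[1.0, 1.0]], np.float32)})
```

tf.get_collection returns a list, hence the [0]; adding the same tensor to a key twice would leave duplicates in that list.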

It is actually explained in the documentation (the exact page you linked), but perhaps you missed it.

By the way, there is no need for the .ckpt extension; it might create some confusion, as that is the old way of saving models.

answered Nov 14 '22 21:11 by Robert Lacok


Just to add to Robert's answer: after obtaining a saver from the meta graph and using it to restore the variables in the current session, you can also use:

y_conv = graph.get_tensor_by_name('y_conv:0')

This works if you created y_conv with an explicit name="y_conv" argument (every TF op constructor accepts a name argument).
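The ":0" suffix is the output index: a tensor name has the form "<op name>:<output index>", and any name scope the op was created in becomes part of the op name. The sketch below (tf.compat.v1 API; the "model" scope and the toy identity op are hypothetical) shows how a scope changes the string you must pass to get_tensor_by_name, and how to list the op names in a graph when you are not sure what a restored op is called.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, [None, 2], name="x")
    with tf.name_scope("model"):
        y_conv = tf.identity(x, name="y_conv")  # created inside a name scope

# The scope prefix becomes part of the name; ":0" selects the first output.
print(y_conv.name)  # model/y_conv:0

# If you don't know what an op is called, list every op name in the graph.
op_names = [op.name for op in graph.get_operations()]
print(op_names)  # ['x', 'model/y_conv']

# Both lookups refer to the same tensor.
t1 = graph.get_tensor_by_name("model/y_conv:0")
t2 = graph.get_operation_by_name("model/y_conv").outputs[0]
```

In a restored graph the same listing works on tf.get_default_graph() right after import_meta_graph.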

answered Nov 14 '22 22:11 by Jonan Gueorguiev