
Get weights from tensorflow model

Hello, I would like to fine-tune a VGG model in TensorFlow. I have two questions.

How do I get the weights from the network? tf.trainable_variables() returns an empty list for me.

I used an existing model from here: https://github.com/ry/tensorflow-vgg16 . I found a post about getting weights, but it doesn't work for me because of import_graph_def: Get the value of some weights in a model trained by TensorFlow

import tensorflow as tf
import PIL.Image
import numpy as np

# Read the serialized GraphDef of the pretrained model.
with open("../vgg16.tfmodel", mode='rb') as f:
  fileContent = f.read()

graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)

# Placeholder that replaces the model's "images" input.
images = tf.placeholder("float", [None, 224, 224, 3])

tf.import_graph_def(graph_def, input_map={ "images": images })
print("graph loaded from disk")

graph = tf.get_default_graph()

cat = np.asarray(PIL.Image.open('../cat224.jpg'))
print(cat.shape)
init = tf.initialize_all_variables()

with tf.Session(graph=graph) as sess:
  print(tf.trainable_variables())  # prints an empty list
  sess.run(init)
Cospel asked Apr 04 '16 20:04



1 Answer

This pretrained VGG-16 model encodes all of its parameters as tf.constant() ops. (See, for example, the calls to tf.constant() here.) As a result, the parameters do not appear in tf.trainable_variables(), and the model cannot be trained further without substantial surgery: to continue training, you would need to replace the constant nodes with tf.Variable objects initialized to the same values.
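
One possible way to do that surgery is sketched below. It is an illustration under assumptions, not the answer author's method: it assumes TensorFlow's 1.x-style graph API (reached here through tf.compat.v1), the helper name import_with_trainable_constants is hypothetical, and the two-node toy graph only stands in for the VGG-16 GraphDef. The idea is to read each stored value out of its Const node and pass a tf.Variable with that value through input_map, so that consumers of the constant read the variable instead.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # the question uses graph-mode (TF1-style) APIs


def import_with_trainable_constants(graph_def, const_names, input_map=None):
    """Hypothetical helper: import graph_def, feeding the listed Const nodes
    from freshly created tf.Variables. Returns {const_name: tf.Variable}."""
    nodes = {node.name: node for node in graph_def.node}
    remap = dict(input_map or {})
    variables = {}
    for name in const_names:
        # A Const node stores its data in the "value" attribute (a TensorProto).
        value = tf.make_ndarray(nodes[name].attr["value"].tensor)
        var = tf.Variable(value, name=name.replace("/", "_") + "_var")
        variables[name] = var
        # Redirect every consumer of the constant's output to the variable.
        remap[name + ":0"] = tf.convert_to_tensor(var)
    tf.import_graph_def(graph_def, input_map=remap, name="import")
    return variables


# Toy stand-in for a frozen model: a single constant weight matrix.
source = tf.Graph()
with source.as_default():
    x = tf.placeholder(tf.float32, [None, 2], name="images")
    w = tf.constant([[1.0], [2.0]], name="w")
    tf.matmul(x, w, name="out")
toy_graph_def = source.as_graph_def()

target = tf.Graph()
with target.as_default():
    images = tf.placeholder(tf.float32, [None, 2])
    variables = import_with_trainable_constants(
        toy_graph_def, ["w"], input_map={"images": images})
    out = target.get_tensor_by_name("import/out:0")
    trainable = tf.trainable_variables()          # now contains the new variable
    grads = tf.gradients(out, [variables["w"]])   # gradients reach the variable
    with tf.Session(graph=target) as sess:
        sess.run(tf.global_variables_initializer())
        result = sess.run(out, {images: np.array([[1.0, 1.0]], np.float32)})
```

For the real model you would have to decide which Const nodes are actually weights (for example by name or shape), since a frozen graph may also contain constants that should stay fixed.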

In general, when importing a graph for retraining, the tf.train.import_meta_graph() function should be used, as this function loads additional metadata (including the collections of variables). The tf.import_graph_def() function is lower level, and does not populate these collections.
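
As a rough sketch of that workflow (again assuming the TF 1.x-style API via tf.compat.v1, with a toy one-variable model and a temporary checkpoint path that are purely illustrative), saving with tf.train.Saver writes a .meta file next to the checkpoint, and tf.train.import_meta_graph() rebuilds the graph with its variable collections populated:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Illustrative checkpoint location; any writable path prefix would do.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "toy_model")

# 1) Train-side: save a toy model. Saver.save() also writes toy_model.meta.
with tf.Graph().as_default():
    w = tf.Variable([[1.0], [2.0]], name="w")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# 2) Retrain-side: rebuild the graph from the MetaGraphDef and restore values.
with tf.Graph().as_default():
    restorer = tf.train.import_meta_graph(ckpt_prefix + ".meta")
    restored_vars = tf.trainable_variables()   # populated by the import
    restored_names = [v.name for v in restored_vars]
    with tf.Session() as sess:
        restorer.restore(sess, ckpt_prefix)
        restored_w = sess.run(restored_vars[0])
```

This only helps when the model was saved with its variables in the first place; a GraphDef whose parameters were frozen into constants carries no such collections.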

mrry answered Oct 08 '22 12:10
