How to read weights saved in tensorflow checkpoint file?

Tags:

tensorflow

I'd like to read the weights and visualize them as images, but I can't find any documentation on the model format or on how to read the trained weights.

mr49 asked Oct 18 '16 21:10



2 Answers

There's a utility that provides a print_tensors_in_checkpoint_file function: http://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/inspect_checkpoint.py
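For reference, here is a minimal sketch of reading a checkpoint without rebuilding the model. It assumes TensorFlow 1.13+ or 2.x, so that `tf.compat.v1` is available. `write_demo_checkpoint` is a hypothetical helper that saves a single variable named `conv1/kernel` so the example has something to read. `read_checkpoint` uses the public `tf.train.list_variables` and `tf.train.load_variable`, and the `__main__` block also calls `print_tensors_in_checkpoint_file` from the utility linked above.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.python.tools import inspect_checkpoint

tf1 = tf.compat.v1  # TF 1.x graph/session API, also present in TF 2.x


def write_demo_checkpoint(directory):
    """Hypothetical helper: save one variable, conv1/kernel, with a name-based Saver."""
    graph = tf.Graph()
    with graph.as_default():
        with tf1.variable_scope('conv1'):
            tf1.get_variable('kernel', initializer=tf.constant(
                np.arange(6, dtype=np.float32).reshape(2, 3)))
        saver = tf1.train.Saver()
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            # Returns the checkpoint prefix, e.g. <directory>/model.ckpt
            return saver.save(sess, os.path.join(directory, 'model.ckpt'))


def read_checkpoint(ckpt_path):
    """Return {variable_name: ndarray} for every tensor stored in a checkpoint."""
    # list_variables yields (name, shape) pairs without building any graph
    return {name: tf.train.load_variable(ckpt_path, name)
            for name, _shape in tf.train.list_variables(ckpt_path)}


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        prefix = write_demo_checkpoint(tmp)
        # The utility from the answer prints every tensor name and value
        inspect_checkpoint.print_tensors_in_checkpoint_file(
            prefix, tensor_name='', all_tensors=True)
        for name, value in read_checkpoint(prefix).items():
            print(name, value.shape)
```

Pass either the checkpoint prefix (for example `model.ckpt-1000`, without the `.index` or `.data-*` suffix) or the checkpoint directory.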

Alternatively, you can use a Saver to restore the model and call session.run on the variable tensors to get their values as NumPy arrays.
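A minimal sketch of the Saver route, again assuming `tf.compat.v1` is available (TF 1.13+ or 2.x). The helper `write_demo_checkpoint` and the variable name `conv1/kernel` are hypothetical stand-ins for your own trained model. `restore_and_fetch` rebuilds the graph from the `.meta` file, restores the saved values, and runs each variable in the session.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x graph/session API, also present in TF 2.x


def write_demo_checkpoint(directory):
    """Hypothetical helper: build a tiny graph, save it with a Saver, return the prefix."""
    graph = tf.Graph()
    with graph.as_default():
        with tf1.variable_scope('conv1'):
            tf1.get_variable('kernel', initializer=tf.constant(
                np.arange(6, dtype=np.float32).reshape(2, 3)))
        saver = tf1.train.Saver()
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            # Writes model.ckpt.index, model.ckpt.data-*, and model.ckpt.meta
            return saver.save(sess, os.path.join(directory, 'model.ckpt'))


def restore_and_fetch(ckpt_prefix):
    """Rebuild the graph from the .meta file, restore values, return {name: ndarray}."""
    graph = tf.Graph()
    with graph.as_default():
        saver = tf1.train.import_meta_graph(ckpt_prefix + '.meta')
        with tf1.Session(graph=graph) as sess:
            saver.restore(sess, ckpt_prefix)  # load the saved variable values
            # session.run on each Variable returns its current value as a NumPy array
            return {v.op.name: sess.run(v) for v in tf1.global_variables()}


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        for name, value in restore_and_fetch(write_demo_checkpoint(tmp)).items():
            print(name, value)
```

Fetching the `Variable` objects from `tf.compat.v1.global_variables()`, rather than a tensor name such as `conv1/kernel:0`, works for both reference and resource variables. For resource variables, the `name:0` tensor is the variable's handle rather than its value.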

Yaroslav Bulatov answered Sep 25 '22 12:09


I wrote this snippet in Python:

import fnmatch
import os

import tensorflow as tf  # TF 1.x API; in TF 2.x use tf.compat.v1 with eager execution disabled


def extracting(meta_dir, var_name='2-convolutional/kernel'):
    # List of .meta files found anywhere under meta_dir (sorted for a stable order)
    configfiles = sorted(os.path.join(dirpath, f)
                         for dirpath, dirnames, files in os.walk(meta_dir)
                         for f in fnmatch.filter(files, '*.meta'))

    with tf.Session() as sess:
        try:
            # A MetaGraph contains both a TensorFlow GraphDef
            # as well as associated metadata necessary
            # for running computation in a graph when crossing a process boundary.
            saver = tf.train.import_meta_graph(configfiles[0])
        except Exception as err:
            # Also reached when no .meta file was found (IndexError)
            print("Unexpected error:", repr(err))
            return None

        # Restore the checkpoint whose prefix matches the last .meta file, e.g.
        # 'model.ckpt-1000.meta' -> 'model.ckpt-1000'. Sorting is lexicographic,
        # so this is not necessarily the latest training step.
        saver.restore(sess, os.path.splitext(configfiles[-1])[0])

        # The graph is now loaded; inside_list holds every node name, which
        # helps you find the name of the tensor you want to inspect
        graph = tf.get_default_graph()
        inside_list = [n.name for n in graph.as_graph_def().node]

        print('Step: ', configfiles[-1])
        print('Tensor:', var_name + ':0')
        w2 = graph.get_tensor_by_name(var_name + ':0')
        print('Tensor shape: ', w2.get_shape())
        w2_saved = sess.run(w2)  # evaluate the tensor as a NumPy array
        print('Tensor value: ', w2_saved)
        return w2_saved

You can run it by passing your pre-trained model directory as meta_dir.
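To view the weights as images, as the question asks, a 4-D convolution kernel such as `w2_saved` can be tiled into a single 2-D grid. This sketch assumes TensorFlow's `[height, width, in_channels, out_channels]` filter layout and shows one input channel at a time. `kernel_to_grid` and its `channel`/`pad` parameters are hypothetical names, and only NumPy is required.

```python
import numpy as np


def kernel_to_grid(kernel, channel=0, pad=1):
    """Tile a conv kernel of shape [height, width, in_channels, out_channels]
    into one 2-D image with one tile per output filter, for a single input
    channel. Each tile is min-max scaled to [0, 1]; padding pixels stay 0."""
    kh, kw, _, n_filters = kernel.shape
    cols = int(np.ceil(np.sqrt(n_filters)))  # roughly square layout
    rows = int(np.ceil(n_filters / cols))
    grid = np.zeros((rows * (kh + pad) - pad, cols * (kw + pad) - pad),
                    dtype=np.float32)
    for i in range(n_filters):
        tile = kernel[:, :, channel, i].astype(np.float32)
        span = tile.max() - tile.min()
        # Constant filters would divide by zero, so draw them as all zeros
        tile = (tile - tile.min()) / span if span > 0 else np.zeros_like(tile)
        r, c = divmod(i, cols)
        top, left = r * (kh + pad), c * (kw + pad)
        grid[top:top + kh, left:left + kw] = tile
    return grid
```

The result can be displayed with Matplotlib, for example `plt.imshow(kernel_to_grid(w2_saved), cmap='gray')`. Scaling each tile separately makes every filter's pattern visible, at the cost of hiding magnitude differences between filters.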

Cloud Cho answered Sep 22 '22 12:09