
How can I convert a trained Tensorflow model to Keras?

I have a trained TensorFlow model and a weights vector, which have been exported to a protobuf file and a weights file, respectively.

How can I convert these to JSON or YAML and HDF5 files which can be used by Keras?

I have the code for the Tensorflow model, so it would also be acceptable to convert the tf.Session to a keras model and save that in code.

Matt asked Jun 09 '17 20:06



2 Answers

I think a Keras callback is also a solution.

TensorFlow can save the checkpoint (.ckpt) file with:

    saver = tf.train.Saver()
    saver.save(sess, checkpoint_name)

To load the checkpoint in Keras, you need a callback class as follows:

    import tensorflow as tf
    import keras

    class RestoreCkptCallback(keras.callbacks.Callback):
        """Restore a TensorFlow checkpoint into the Keras session before training."""

        def __init__(self, pretrained_file):
            super(RestoreCkptCallback, self).__init__()
            self.pretrained_file = pretrained_file
            self.sess = keras.backend.get_session()
            # Create the Saver after the model is built, so its variables exist.
            self.saver = tf.train.Saver()

        def on_train_begin(self, logs=None):
            if self.pretrained_file:
                self.saver.restore(self.sess, self.pretrained_file)
                print('load weights: OK.')

Then in your Keras script:

    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
    restore_ckpt_callback = RestoreCkptCallback(pretrained_file='./XXXX.ckpt')
    model.fit(x_train, y_train, batch_size=128, epochs=20,
              callbacks=[restore_ckpt_callback])

That should work. It is easy to implement, and I hope it helps.
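One caveat: tf.train.Saver() restores variables by name, so the restore only succeeds if the checkpoint holds every variable of the Keras graph under the same name. Below is a minimal sketch for comparing the two sets of names before training. Both diff_variable_names and strip_output_suffix are hypothetical helpers, not part of TensorFlow or Keras. The name lists are assumed to come from tf.train.list_variables(...) for the checkpoint and from tf.global_variables() for the graph, and the names in the example are made up.

```python
def strip_output_suffix(name):
    """Drop a trailing ':<index>' such as ':0' from a TensorFlow variable name."""
    head, sep, tail = name.rpartition(':')
    return head if sep and tail.isdigit() else name


def diff_variable_names(checkpoint_names, graph_names):
    """Return (missing, unused) as sorted lists.

    missing: graph variables with no counterpart in the checkpoint
    unused:  checkpoint variables that the graph never asks for
    """
    in_ckpt = {strip_output_suffix(n) for n in checkpoint_names}
    in_graph = {strip_output_suffix(n) for n in graph_names}
    return sorted(in_graph - in_ckpt), sorted(in_ckpt - in_graph)


# Made-up names for illustration only.
missing, unused = diff_variable_names(
    checkpoint_names=['conv1/kernel', 'conv1/bias', 'global_step'],
    graph_names=['conv1/kernel:0', 'conv1/bias:0', 'dense_1/kernel:0'],
)
print('missing from checkpoint:', missing)
print('unused in checkpoint:', unused)
```

An empty "missing" list suggests the Saver can find everything it needs. A non-empty one means the restore would most likely fail with a not-found error for those names.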

Jiang Xiang answered Oct 21 '22 17:10


Francois Chollet, the creator of Keras, stated in April 2017: "you cannot turn an arbitrary TensorFlow checkpoint into a Keras model. What you can do, however, is build an equivalent Keras model then load into this Keras model the weights"; see https://github.com/keras-team/keras/issues/5273. To my knowledge, this hasn't changed.

A small example:

First, you can extract the weights from a TensorFlow checkpoint like this:

    import tensorflow as tf

    PATH_REL_META = r'checkpoint1.meta'

    # start tensorflow session
    with tf.Session() as sess:

        # import graph
        saver = tf.train.import_meta_graph(PATH_REL_META)

        # load weights for graph (checkpoint prefix = path without '.meta')
        saver.restore(sess, PATH_REL_META[:-5])

        # get all global variables (including model variables)
        vars_global = tf.global_variables()

        # get their name and value and put them into dictionary
        # (the with-block already makes sess the default session for eval())
        model_vars = {}
        for var in vars_global:
            try:
                model_vars[var.name] = var.eval()
            except Exception:
                print("For var={}, an exception occurred".format(var.name))
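Before mapping these values onto Keras layers, it can help to look at what was actually extracted. The sketch below lists each entry of model_vars with its shape. Here summarize_vars is a hypothetical helper, and the stand-in values only mimic the .shape attribute of the NumPy arrays that var.eval() returns.

```python
from types import SimpleNamespace


def summarize_vars(model_vars):
    """Return sorted (name, shape) pairs; shape is None if a value has no .shape."""
    return sorted(
        (name, getattr(value, 'shape', None))
        for name, value in model_vars.items()
    )


# Stand-ins for the arrays from var.eval(); only .shape matters here.
example_vars = {
    'conv7x7x1_1/kernel:0': SimpleNamespace(shape=(7, 7, 1, 6)),
    'conv7x7x1_1/bias:0': SimpleNamespace(shape=(6,)),
}

summary = summarize_vars(example_vars)
for name, shape in summary:
    print('{:<28} {}'.format(name, shape))
```

For a Conv2D layer with 6 filters of size 7x7 on a single-channel input, Keras expects a kernel of shape (7, 7, 1, 6) and a bias of shape (6,), so the extracted shapes should match before the weights are set.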

It may also be useful to export the TensorFlow model for viewing in TensorBoard; see https://stackoverflow.com/a/43569991/2135504

Second, build your Keras model as usual and finalize it with "model.compile". Note that you need to assign each layer to a named variable and only then add it to the model, so that you can refer to the layer later, e.g.

    # assumed: net is a Sequential model
    net = keras.models.Sequential()
    layer_1 = keras.layers.Conv2D(6, (7,7), activation='relu', input_shape=(48,48,1))
    net.add(layer_1)
    ...
    net.compile(...)

Third, you can set the weights with the tensorflow values, e.g.

    layer_1.set_weights([
        model_vars['conv7x7x1_1/kernel:0'],
        model_vars['conv7x7x1_1/bias:0'],
    ])
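With many layers, writing one set_weights call per layer gets repetitive. Below is a minimal sketch that groups the entries of model_vars by layer prefix, so that each group can be passed to set_weights. Note that group_layer_weights is a hypothetical helper, and it assumes every relevant variable is named '<layer>/kernel:0' or '<layer>/bias:0', as in the example above. Anything else, such as optimizer slot variables, is skipped.

```python
from collections import OrderedDict


def group_layer_weights(model_vars, order=('kernel', 'bias')):
    """Map each layer prefix to its weights, ordered kernel first, then bias.

    Variables whose last name component is not in `order` are skipped.
    """
    found = {}
    for name, value in model_vars.items():
        prefix, _, leaf = name.rpartition('/')
        kind = leaf.split(':')[0]  # 'kernel:0' -> 'kernel'
        if prefix and kind in order:
            found.setdefault(prefix, {})[kind] = value

    grouped = OrderedDict()
    for prefix in sorted(found):
        parts = found[prefix]
        grouped[prefix] = [parts[k] for k in order if k in parts]
    return grouped


# Strings stand in for the weight arrays; the last two names are made up
# to show that non-layer variables are ignored.
example_vars = {
    'conv7x7x1_1/bias:0': 'b1',
    'conv7x7x1_1/kernel:0': 'k1',
    'conv7x7x1_1/kernel/Adam:0': 'slot',
    'beta1_power:0': 'scalar',
}
grouped = group_layer_weights(example_vars)
print(grouped)
```

Each list could then be applied with something like layer_1.set_weights(grouped['conv7x7x1_1']), provided the layer's weights are a kernel followed by a bias.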
gebbissimo answered Oct 21 '22 18:10