 

Tensorflow: Finetune pretrained model on new dataset with different number of classes

Tags:

tensorflow

How can I finetune a pretrained model in tensorflow on a new dataset? In Caffe I can simply rename the last layer and set some parameters for random initialization. Is something similar possible in tensorflow?

Say I have a checkpoint file (deeplab_resnet.ckpt) and some code that sets up the computational graph, in which I can modify the last layer so that it has the same number of outputs as the new dataset has classes.

Then I try to start the session like this:

sess = tf.Session(config=config)
init = tf.initialize_all_variables()

sess.run(init)

trainable = tf.trainable_variables()
saver = tf.train.Saver(var_list=trainable, max_to_keep=40)
saver.restore(sess, 'ckpt_path/deeplab_resnet.ckpt')

However, this gives me an error when calling saver.restore, since it expects the exact same graph structure as the one the checkpoint was saved from. How can I load all weights except for the last layer from the 'ckpt_path/deeplab_resnet.ckpt' file? I also tried renaming the classification layer, but no luck there either...

I'm using the tensorflow-deeplab-resnet model

mcExchange asked Jan 19 '17

People also ask

What is the difference between transfer learning and fine tuning?

Transfer learning is when a model developed for one task is reused as the starting point for a second task. Fine-tuning is one approach to transfer learning in which you replace the model's output layer to fit the new task and then retrain some or all of the weights on the new dataset, often only the new output layer at first.
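As a rough illustration of "train only the output layer", here is a minimal sketch in plain Python (the scope-style parameter names and values are illustrative, not a real model): the base weights are left untouched and only parameters under the new head's prefix receive updates.

```python
# Toy parameters keyed by scope-style names (illustrative only)
params = {'base/conv1': 1.0, 'base/conv2': 2.0, 'head/fc': 0.5}

# Pretend gradients from one training step
grads = {name: 0.1 for name in params}

lr = 1.0
for name in params:
    if name.startswith('head/'):  # only the new output layer is trainable
        params[name] -= lr * grads[name]

print(params)  # base/* unchanged; only head/fc is updated
```

In a real framework the same effect is achieved by passing only the head's variables to the optimizer (or freezing the base layers).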


1 Answer

You can specify the names of the variables that you want to restore.

So, you can get a list of all of the variables in the model and filter out the variables of the last layer:

all_vars = tf.all_variables()  # tf.global_variables() in later TF 1.x releases
# 'xxx' stands for the variable scope / name prefix of the replaced last layer
var_to_restore = [v for v in all_vars if not v.name.startswith('xxx')]

saver = tf.train.Saver(var_to_restore)

See the tf.train.Saver documentation for the details.
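The filtering step itself is easy to test in isolation. This sketch uses plain strings shaped like TensorFlow variable names (`scope/name:0`); the prefix `fc_out` is a hypothetical scope for the replaced classification layer, standing in for the `'xxx'` placeholder:

```python
def vars_to_restore(var_names, new_layer_prefix):
    """Keep every checkpoint variable except those under the
    re-initialized last layer's scope (matched by name prefix)."""
    return [n for n in var_names if not n.startswith(new_layer_prefix)]

# Hypothetical names in the style tf.all_variables() reports them
names = ['conv1/weights:0', 'res5c/bn/gamma:0',
         'fc_out/weights:0', 'fc_out/biases:0']

print(vars_to_restore(names, 'fc_out'))
# → ['conv1/weights:0', 'res5c/bn/gamma:0']
```

The surviving list is what you would hand to `tf.train.Saver` as `var_list`; the excluded last-layer variables keep their fresh random initialization.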

Alternatively, you can try to load the whole model and create a new "branch" off the layer before the last, then use that branch's output in the cost function during training.
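As a sketch of the "branch" idea, assuming you can get hold of the penultimate layer's activations: attach a freshly initialized classification head sized for the new number of classes and compute logits from those features. NumPy stands in for the graph op here, and the shapes (2048 features, 21 classes) are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the penultimate layer's output: (batch, feature_dim)
features = rng.standard_normal((4, 2048))

# New, randomly initialized head sized for the new dataset's classes
num_new_classes = 21
w = rng.standard_normal((2048, num_new_classes)) * 0.01
b = np.zeros(num_new_classes)

logits = features @ w + b  # these feed the cost function during training
print(logits.shape)  # (4, 21)
```

Because the head is a new variable, it is simply absent from the checkpoint, so the restore-mismatch problem never arises for it.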

Alexey Romanov answered Sep 29 '22