
Retraining the last layer of Inception-ResNet-v2

I am trying to retrain the last layer of Inception-ResNet-v2. Here's what I came up with:

  1. Get the names of the variables in the final layer.
  2. Create a train_op that minimises the loss with respect to only these variables.
  3. Restore the whole graph except the final layer, and initialise only the final layer randomly.

And I implemented that as follows:

with slim.arg_scope(arg_scope):
    logits = model(images_ph, is_training=True, reuse=None)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels_ph))
accuracy = tf.contrib.metrics.accuracy(tf.argmax(logits, 1), labels_ph)

train_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'InceptionResnetV2/Logits')
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)

train_op = optimizer.minimize(loss, var_list=train_list)

# restore all variables whose names don't contain 'Logits'
restore_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='^((?!Logits).)*$')

saver = tf.train.Saver(restore_list, write_version=tf.train.SaverDef.V2)

with tf.Session() as session:
    init_op = tf.group(tf.local_variables_initializer(), tf.global_variables_initializer())

    session.run(init_op)
    saver.restore(session, '../models/inception_resnet_v2_2016_08_30.ckpt')

    # followed by code for running train_op

This doesn't seem to work: the training loss and error barely improve from their initial values. Is there a better or more elegant way to do this? It would also help me learn if you could tell me what is going wrong here.
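
A note on the two variable selections above: according to the TensorFlow 1.x documentation, the `scope` argument of `tf.get_collection` is applied to each item's name with `re.match`. The sketch below uses only Python's `re` module to show what the two patterns would select under that assumption. The variable names are illustrative, modelled on the slim naming scheme rather than read from a real graph.

```python
import re

# Illustrative variable names (assumed, in the style of the slim
# Inception-ResNet-v2 graph; not taken from a real checkpoint).
names = [
    "InceptionResnetV2/Conv2d_1a_3x3/weights:0",
    "InceptionResnetV2/Conv2d_1a_3x3/BatchNorm/beta:0",
    "InceptionResnetV2/Logits/Logits/weights:0",
    "InceptionResnetV2/Logits/Logits/biases:0",
    "InceptionResnetV2/AuxLogits/Logits/weights:0",
]

# Pattern used for restore_list: matches names that contain no "Logits" anywhere.
exclude_logits = re.compile(r"^((?!Logits).)*$")
# Pattern used for train_list: re.match anchors at the start of the name.
final_layer = re.compile(r"InceptionResnetV2/Logits")

restore_names = [n for n in names if exclude_logits.match(n)]
train_names = [n for n in names if final_layer.match(n)]
neither = [n for n in names if n not in restore_names and n not in train_names]

print("restored:", restore_names)
print("trained: ", train_names)
print("neither: ", neither)
```

Under these assumed names, the `AuxLogits` variable lands in neither list, so it would be neither restored nor trained. Separately, `TRAINABLE_VARIABLES` does not include non-trainable variables such as the batch-norm moving averages, so a saver built from that collection would not restore them.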

Priyatham asked Dec 31 '16 09:12


1 Answer

There are several things to check:

  • what is the learning rate? a value that is too high can break training (probably not the cause here)
  • try plain stochastic gradient descent instead of Adam; you should have fewer problems
  • is the arg scope set correctly? it is what configures L2 regularization and batch normalization for the layers; without them the network can fall into a poor local minimum very early and be unable to learn:

    from nets import inception_resnet_v2 as net
    with net.inception_resnet_v2_arg_scope():
        logits, end_points = net.inception_resnet_v2(images_ph, num_classes=num_classes,
                                                     is_training=True)
    
  • you should add the regularization losses to the loss (or at least those of the last layer):

    regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    all_losses = [loss] + regularization_losses
    total_loss = tf.add_n(all_losses, name='total_loss')
    
  • training only the fully connected layer might not be a good idea. I would train the whole network, because the features you need for your classes are not necessarily computed in the last layer but a few layers earlier, and those layers need to change too.

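  • double-check that the train_op runs whenever the loss is evaluated, by making the loss depend on it:

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import control_flow_ops

    with ops.name_scope('train_op'):
        train_op = control_flow_ops.with_dependencies([train_op], total_loss)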
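
To make the regularization and last-layer points concrete without TensorFlow, here is a minimal, framework-free sketch. A fixed feature function stands in for the frozen pretrained layers, only a softmax layer on top is trained, and the total loss is the cross-entropy plus an L2 penalty. Every name here (`frozen_features`, `train_last_layer`, the toy XOR data, the hyperparameters) is a hypothetical choice for illustration, not part of slim or of the code in the question.

```python
import math
import random


def softmax(z):
    # Numerically stable softmax over a list of scores.
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def frozen_features(x):
    # Stand-in for the pretrained, frozen part of the network (assumption).
    return [x[0], x[1], x[0] * x[1], 1.0]


def total_loss(W, data, weight_decay):
    # Mean cross-entropy plus the L2 regularization term on the last layer.
    ce = 0.0
    for x, y in data:
        f = frozen_features(x)
        p = softmax([sum(w * v for w, v in zip(row, f)) for row in W])
        ce += -math.log(p[y])
    ce /= len(data)
    l2 = weight_decay * sum(w * w for row in W for w in row)
    return ce + l2


def train_last_layer(data, num_classes=2, lr=0.1, weight_decay=1e-3,
                     steps=200, seed=0):
    rng = random.Random(seed)
    dim = len(frozen_features(data[0][0]))
    # Only these weights are trained; the feature function never changes.
    W = [[rng.uniform(-0.01, 0.01) for _ in range(dim)]
         for _ in range(num_classes)]
    history = [total_loss(W, data, weight_decay)]
    for _ in range(steps):
        # Gradient of the L2 term, then add the cross-entropy gradient.
        grad = [[2 * weight_decay * w for w in row] for row in W]
        for x, y in data:
            f = frozen_features(x)
            p = softmax([sum(w * v for w, v in zip(row, f)) for row in W])
            for k in range(num_classes):
                err = (p[k] - (1.0 if k == y else 0.0)) / len(data)
                for j in range(dim):
                    grad[k][j] += err * f[j]
        W = [[w - lr * g for w, g in zip(row, grow)]
             for row, grow in zip(W, grad)]
        history.append(total_loss(W, data, weight_decay))
    return W, history


# Toy XOR data: (input, class label).
data = [((0.0, 0.0), 0), ((0.0, 1.0), 1), ((1.0, 0.0), 1), ((1.0, 1.0), 0)]
W, history = train_last_layer(data)
print("total loss before:", history[0], "after:", history[-1])
```

The sketch only works because the assumed frozen features already separate the classes; if they did not, training the last layer alone would stall, which is the situation the advice about training the whole network addresses.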
    
jorgemf answered Sep 23 '22 02:09
