
L2 regularization loss keeps increasing during training

I am fine-tuning InceptionResnetV2 in TensorFlow. During training, the regularization loss keeps increasing linearly, and in the later stages of training it becomes much larger than the cross-entropy loss. I have checked the training procedure and made sure that I am optimizing the sum of the cross-entropy loss and the L2 loss.

Can anyone explain this odd behavior? Any feedback is appreciated.

Here is the code, along with some TensorBoard plots.

import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope
import os
import time
from preprocessing import aug_parallel_v2
import numpy as np

slim = tf.contrib.slim

# total training data number
sample_num = 625020

data_path = 'iNaturalist_train.tfrecords'

# State where your log file is at. If it doesn't exist, create it.
log_dir = './log_v5'
# tensorboard visualization path
filewriter_path = './filewriter_v5_Logits'

# State where your checkpoint file is
checkpoint_file = './inception_resnet_v2_2016_08_30.ckpt'
checkpoint_save_addr = './log_v5/fine-tuning_v5.ckpt'
# State the image size you're resizing your images to. We will use the default inception size of 299.
image_size = 299

# State the number of classes to predict:
num_classes = 8142

# ================= TRAINING INFORMATION ==================
# State the number of epochs to train
num_epochs = 5

# State your batch size
batch_size = 60

# Learning rate information and configuration
initial_learning_rate = 0.0005
learning_rate_decay_factor = 0.8
num_epochs_before_decay = 2

# Weight each class inversely proportional to its number of training
# images, then normalize so that the mean weight is 1.
label_count = np.loadtxt('label_count.txt', dtype=int)
# Cast to float first so the division is never an integer division.
multiplier = 1.0 / label_count.astype(np.float64)
multiplier /= np.mean(multiplier)

def run():

    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    feature = {'train/height': tf.FixedLenFeature([], tf.int64),
               'train/width': tf.FixedLenFeature([], tf.int64),
               'train/image': tf.FixedLenFeature([], tf.string),
               'train/label': tf.FixedLenFeature([], tf.int64),
               'train/sup_label': tf.FixedLenFeature([], tf.int64),
               'train/aug_level': tf.FixedLenFeature([], tf.int64)}

    # create a list of file names
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=None)
    print(filename_queue)

    reader = tf.TFRecordReader()
    _, tfrecord_serialized = reader.read(filename_queue)

    features = tf.parse_single_example(tfrecord_serialized, features=feature)

    # Convert the image data from string back to the numbers
    height = tf.cast(features['train/height'], tf.int64)
    width = tf.cast(features['train/width'], tf.int64)

    # change this line for your TFrecord version
    tf_image = tf.image.decode_jpeg(features['train/image'])

    tf_label = tf.cast(features['train/label'], tf.int32)
    aug_level = tf.cast(features['train/aug_level'], tf.int32)
    # tf_sup_label = tf.cast(features['train/sup_label'], tf.int64)

    tf_image = tf.reshape(tf_image, tf.stack([height, width, 3]))
    tf_label = tf.reshape(tf_label, [1])
    aug_level = tf.reshape(aug_level, [1])

    resized_image = tf.image.resize_images(images=tf_image, size=tf.constant([400, 400]), method=2)
    resized_image = tf.cast(resized_image, tf.uint8)
    tf_images, tf_labels, tf_aug = tf.train.shuffle_batch([resized_image, tf_label, aug_level], batch_size=batch_size,
                                                      capacity=2048, num_threads=16, allow_smaller_final_batch=False,
                                                      min_after_dequeue=256)


    tf.logging.set_verbosity(tf.logging.INFO)  # Set the verbosity to INFO level

    IMAGE_HEIGHT = 299
    IMAGE_WIDTH = 299

    images = tf.placeholder(dtype=tf.float32, shape=[None, 299, 299, 3])
    labels = tf.placeholder(dtype=tf.int32, shape=[None, 1])
    weighted_level = tf.placeholder(dtype=tf.float32, shape=[None, 1])

    # Know the number steps to take before decaying the learning rate and batches per epoch
    num_batches_per_epoch = int(sample_num / batch_size)
    num_steps_per_epoch = num_batches_per_epoch  # Because one step is one batch processed
    decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)

    # Create the model inference
    with slim.arg_scope(inception_resnet_v2_arg_scope()):
        logits, end_points = inception_resnet_v2(images, num_classes=num_classes, is_training=True)

    # Define the scopes that you want to exclude for restoration
    exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
    variables_to_restore = slim.get_variables_to_restore(exclude=exclude)

    print("label test")
    print(labels)
    print(logits)

    # Perform one-hot-encoding of the labels (Try one-hot-encoding within the load_batch function!)
    one_hot_labels = tf.squeeze(tf.one_hot(labels, num_classes), [1])

    print(one_hot_labels)
    print(logits)

    weighted_onehot = tf.multiply(one_hot_labels, weighted_level)

    # Softmax cross-entropy on the class-weighted one-hot labels
    digits_loss = tf.losses.softmax_cross_entropy(onehot_labels=weighted_onehot, logits=logits)

    reg_loss = tf.losses.get_regularization_loss()

    total_loss = digits_loss + reg_loss

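    # Create the global step; it was used below without being defined
    global_step = tf.train.get_or_create_global_step()

    # Define your exponentially decaying learning rate
    lr = tf.train.exponential_decay(
        learning_rate=initial_learning_rate,
        global_step=global_step,
        decay_steps=decay_steps,
        decay_rate=learning_rate_decay_factor,
        staircase=True)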

    # train_vars = []
    # Now we can define the optimizer that takes on the learning rate
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          "InceptionResnetV2/Logits")

    # RMSProp or Adam

    optimizer = tf.train.AdamOptimizer(learning_rate=lr)

    # Create the train_op.
    train_op = slim.learning.create_train_op(total_loss, optimizer, variables_to_train=train_vars)

    predictions = tf.argmax(end_points['Predictions'], 1)
    probabilities = end_points['Predictions']
    accuracy, accuracy_update = tf.metrics.accuracy(predictions, labels)
    metrics_op = tf.group(accuracy_update, probabilities)

    tf.summary.scalar('losses/Reg_Loss', reg_loss)
    tf.summary.scalar('losses/Digit_Loss', digits_loss)
    tf.summary.scalar('losses/Total_Loss', total_loss)
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.scalar('learning_rate', lr)
    writer = tf.summary.FileWriter(filewriter_path)
    writer.add_graph(tf.get_default_graph())

    my_summary_op = tf.summary.merge_all()

    def train_step(sess, train_op, global_step, imgs, lbls, weight):
        '''
        Runs one training step (train_op, global_step and metrics_op) on the given batch
        and logs the total loss and the time elapsed for that global step
        '''
        # Check the time for each sess run
        start_time = time.time()

        total_loss, global_step_count, _ = sess.run([train_op, global_step, metrics_op],
                                                    feed_dict={images: imgs, labels: lbls, weighted_level: weight})

        time_elapsed = time.time() - start_time

        # Log the total loss (cross entropy + regularization) and the step time
        logging.info('global step %s: total_loss: %.4f (%.2f sec/step)',
                     global_step_count, total_loss, time_elapsed)

        return total_loss, global_step_count

    saver_pretrain = tf.train.Saver(variables_to_restore)
    saver_train = tf.train.Saver(train_vars)

    with tf.Session() as sess:

        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(init_op)

        # Create a coordinator and run all QueueRunner objects
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        saver_pretrain.restore(sess, checkpoint_file)

        start_time = time.time()

        for step in range(int(num_steps_per_epoch * num_epochs)):

            imgs, lbls, augs = sess.run([tf_images, tf_labels, tf_aug])

            imgs, lbls = aug_parallel_v2(imgs, lbls, augs)

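            # Center-crop the 400x400 images to 299x299
            imgs = imgs[:, 50:349, 50:349, :]

            # Rescale to [-1, 1]. This assumes aug_parallel_v2 returns pixel values
            # in [0, 1]; for uint8 values in [0, 255], divide by 255 first.
            imgs = 2*(imgs.astype(np.float32)) - 1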

            lbls = lbls.astype(np.int32)

            weight = multiplier[lbls]

            weight = np.array(weight).reshape((batch_size, 1))

            # print(imgs[0, 0:10, 0:10, 0:2])

            if step % num_batches_per_epoch == 0:
                logging.info('Epoch %s/%s', step // num_batches_per_epoch + 1, num_epochs)

                learning_rate_value, accuracy_value = sess.run([lr, accuracy],
                                                feed_dict={images: imgs, labels: lbls, weighted_level: weight})

                logging.info('Current Learning Rate: %s', learning_rate_value)
                logging.info('Current Streaming Accuracy: %s', accuracy_value)

                # optionally, print your logits and predictions for a sanity check that things are going fine.
                logits_value, probabilities_value, predictions_value, labels_value = sess.run(
                    [logits, probabilities, predictions, labels],
                    feed_dict={images: imgs, labels: lbls, weighted_level: weight})

                print('logits: \n', logits_value)

                print('Probabilities: \n', probabilities_value)

                print('predictions: \n', predictions_value)

                print('Labels:\n:', labels_value)

            # Log the summaries every 20 steps.
            if step % 20 == 0:

                loss, global_step_count = train_step(sess, train_op, global_step, imgs, lbls, weight)

                summaries = sess.run(my_summary_op, feed_dict={images: imgs, labels: lbls, weighted_level: weight})

                writer.add_summary(summaries, global_step_count)
                # sess.summary_computed(sess, summaries)

            # If not, simply run the training step

            else:
                loss, _ = train_step(sess, train_op, global_step, imgs, lbls, weight)

            if step % 2000 == 0:

                logging.info('Saving model to disk now.')
                saver_train.save(sess, checkpoint_save_addr, global_step=global_step)

            print('one batch time: ', time.time() - start_time)

            start_time = time.time()

        # We log the final training loss and accuracy
        logging.info('Final Loss: %s', loss)
        logging.info('Final Accuracy: %s', sess.run(accuracy))

        # Once all the training has been done, save the log files and checkpoint model
        logging.info('Finished training! Saving model to disk now.')
        saver_train.save(sess, checkpoint_save_addr, global_step=global_step)

        # Stop the threads
        coord.request_stop()

        # Wait for threads to stop
        coord.join(threads)
        sess.close()

if __name__ == '__main__':
    run()

I am new here and don't have enough reputation to post images, so here are two links to the accuracy plot and the losses plot. You can easily see that the regularization loss dominates.

Tong Shen asked Oct 14 '25 05:10


1 Answer

This is a difficult question to answer. I can give some pointers though.

In general, when you minimize digits_loss, that is, when you fit your model to your data, you slowly change the weights in your layers. To counter potential overfitting, an L2 regularization loss (based on the sum of the squares of all weights, reg_loss in your code) is usually added to the overall loss (total_loss in your code). These two terms generally pull against each other, and if the balance is right, you train a good model.
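To make the two competing terms concrete, here is a minimal sketch in plain Python (no TensorFlow). It assumes the penalty has the form weight_decay * sum(w^2) / 2, which is my understanding of what slim's l2_regularizer computes. The helper names l2_penalty and total_loss and all numbers are hypothetical and only for illustration.

```python
def l2_penalty(weights, weight_decay):
    """Return weight_decay * sum(w^2) / 2 for a flat list of weights."""
    return weight_decay * sum(w * w for w in weights) / 2.0


def total_loss(data_loss, weights, weight_decay):
    """Combine the data-fitting loss with the L2 penalty."""
    return data_loss + l2_penalty(weights, weight_decay)


# Illustrative values only: two weight vectors that differ by a factor of 100.
small_weights = [0.1, -0.2, 0.3]
large_weights = [10.0, -20.0, 30.0]

penalty_small = l2_penalty(small_weights, weight_decay=0.00004)
penalty_large = l2_penalty(large_weights, weight_decay=0.00004)

# The penalty grows with the square of the weight magnitudes.
print(penalty_small)
print(penalty_large)
print(total_loss(1.5, large_weights, weight_decay=0.00004))
```

Because the penalty is quadratic in the weights, a layer whose weights keep growing during training makes reg_loss rise even though weight_decay itself is tiny.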

In your case, you are taking a network (InceptionResnetV2) that was developed for 1,001 classes and using it to predict 8,142 classes. There is no problem with that per se, but you are upsetting the balance. So I believe you need to override the default weight decay of 0.00004 for InceptionResnetV2 with some higher value, in this line (note that there are only 3 zeros after the decimal point, for a 10x increase):

with slim.arg_scope(inception_resnet_v2_arg_scope(weight_decay=0.0004)):
    logits, end_points = inception_resnet_v2(images, num_classes=num_classes, is_training=True)

A higher weight_decay value penalizes large weights more strongly, so it pushes the optimizer harder to keep the weights small, and with them the sum of squared weights. The problem is that this number is just a guess; I have no idea what the ideal value would be. You need to experiment with several values to find out.
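As a rough illustration of such an experiment, here is a toy sweep in plain Python rather than the actual InceptionResnetV2 setup. It trains a 1-D logistic regression on linearly separable data, where the cross-entropy term alone keeps pushing the weight up, and compares several weight_decay values. The function train_logistic, the toy data, and the chosen values are all assumptions made for this sketch.

```python
import math


def train_logistic(weight_decay, steps=2000, lr=0.1):
    """Gradient descent on a 1-D logistic regression with an L2 penalty.

    The toy data is linearly separable, so without a penalty the weight
    keeps growing as the cross-entropy is pushed towards zero.
    """
    xs = [-2.0, -1.0, 1.0, 2.0]
    ys = [0.0, 0.0, 1.0, 1.0]
    w = 0.0
    for _ in range(steps):
        grad = 0.0
        for x, y in zip(xs, ys):
            p = 1.0 / (1.0 + math.exp(-w * x))
            grad += (p - y) * x  # gradient of the cross-entropy term
        # Mean data gradient plus the gradient of weight_decay * w^2 / 2.
        grad = grad / len(xs) + weight_decay * w
        w -= lr * grad
    return w


# Hypothetical sweep; the right range for a real model has to be found empirically.
decays = [0.0, 0.00004, 0.0004, 0.004, 0.04]
final_weights = [train_logistic(wd) for wd in decays]

for wd, w in zip(decays, final_weights):
    print("weight_decay=%g  |w|=%.3f  l2_term=%.6f" % (wd, abs(w), wd * w * w / 2.0))
```

In this toy setting, larger weight_decay values end with a smaller weight. On the real model, the equivalent check would be to compare the Reg_Loss and Digit_Loss curves in TensorBoard for each candidate value.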

Peter Szoldan answered Oct 19 '25 22:10


