 

Using batch norm when restoring the model?

I have a small problem using batch norm when restoring a model in TensorFlow.

Below is my batch norm function, which is from here:

def _batch_normalization(self, input_tensor, is_training, batch_norm_epsilon, decay=0.999):
    """batch normalization for dense nets.

    Args:
        input_tensor: `tensor`, the input tensor which needs to be normalized.
        is_training: `bool`, if true, update the mean/variance using a moving average;
                             otherwise, use the stored mean/variance.
        batch_norm_epsilon: `float`, param for batch normalization.
        decay: `float`, param for updating the moving average, default is 0.999.

    Returns:
        the normalized tensor.
    """
    # batch normalization acts along the channels dimension.
    input_shape_channels = int(input_tensor.get_shape()[-1])

    # scale and beta are used in the formula: scale * (x - E(x)) / sqrt(var(x)) + beta
    scale = tf.Variable(tf.ones([input_shape_channels]))
    beta = tf.Variable(tf.zeros([input_shape_channels]))

    # global mean and var hold the moving averages of the batch statistics.
    global_mean = tf.Variable(tf.zeros([input_shape_channels]), trainable=False)
    global_var = tf.Variable(tf.ones([input_shape_channels]), trainable=False)

    # if training, update the mean and var; otherwise use the stored mean/var directly.
    if is_training:
        # batch norm in the channel axis.
        axis = list(range(len(input_tensor.get_shape()) - 1))
        batch_mean, batch_var = tf.nn.moments(input_tensor, axes=axis)

        # update the mean and var.
        train_mean = tf.assign(global_mean, global_mean * decay + batch_mean * (1 - decay))
        train_var = tf.assign(global_var, global_var * decay + batch_var * (1 - decay))
        with tf.control_dependencies([train_mean, train_var]):
            return tf.nn.batch_normalization(input_tensor,
                                             batch_mean, batch_var, beta, scale, batch_norm_epsilon)
    else:
        return tf.nn.batch_normalization(input_tensor,
                                         global_mean, global_var, beta, scale, batch_norm_epsilon)

I trained the model and saved it with tf.train.Saver(). Below is the test code:

def inference(self, images_for_predict):
    """load the pre-trained model and do the inference.

    Args:
        images_for_predict: `tensor`, images to predict with the pre-trained model.

    Returns:
        the predicted labels.
    """

    tf.reset_default_graph()
    images, labels, _, _, prediction, accuracy, saver = self._build_graph(1, False)

    predictions = []
    correct = 0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # saver = tf.train.import_meta_graph('./models/dense_nets_model/dense_nets.ckpt.meta')
        # saver.restore(sess, tf.train.latest_checkpoint('./models/dense_nets_model/'))
        saver.restore(sess, './models/dense_nets_model/dense_nets.ckpt')
        for i in range(100):
            pred, corr = sess.run([tf.argmax(prediction, 1), accuracy],
                                  feed_dict={
                                      images: [images_for_predict.images[i]],
                                      labels: [images_for_predict.labels[i]]})
            correct += corr
            predictions.append(pred[0])
    print("PREDICTIONS:", predictions)
    print("ACCURACY:", correct / 100)

But the prediction results are always very bad, like this:

('PREDICTIONS:', [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

('ACCURACY:', 0.080000000000000002)

Some notes: images_for_predict = mnist.test, and the self._build_graph method takes two params: batch_size and is_training.

Can anyone help me?

asked Oct 18 '17 by Yang


2 Answers

After trying a lot of approaches, I solved this problem. Below is what I did.

First, thanks to @gdelab, I used tf.layers.batch_normalization instead, so my batch norm function looks like this:

def _batch_normalization(self, input_tensor, is_training):
    return tf.layers.batch_normalization(input_tensor, training=is_training)

The is_training param is a placeholder: is_training = tf.placeholder(tf.bool)

When building your graph, remember to add this code to your optimization step:

extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(extra_update_ops):
    train_step = tf.train.AdamOptimizer(self.learning_rate).minimize(cross_entropy)

This is because the update ops that tf.layers.batch_normalization adds for the moving mean and variance don't automatically become dependencies of the train operation - so if you don't do anything extra, they never get run.
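To make the feeding pattern concrete, here is a minimal sketch of how the is_training placeholder is fed during training versus inference. The toy layer sizes and names (x, h, loss, train_step) are only illustrative, not from the original code:

import numpy as np
import tensorflow as tf

# toy graph just to show the feed pattern; sizes and names are made up.
x = tf.placeholder(tf.float32, [None, 4])
is_training = tf.placeholder(tf.bool)

h = tf.layers.dense(x, 8)
h = tf.layers.batch_normalization(h, training=is_training)
loss = tf.reduce_mean(tf.square(h))

# tie the moving mean/variance updates (UPDATE_OPS) to the train op.
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(extra_update_ops):
    train_step = tf.train.AdamOptimizer(1e-3).minimize(loss)

data = np.random.randn(16, 4).astype(np.float32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # training step: batch statistics are used and the moving averages get updated.
    sess.run(train_step, feed_dict={x: data, is_training: True})
    # inference: the stored moving mean/variance are used, even with a batch of 1.
    out = sess.run(h, feed_dict={x: data[:1], is_training: False})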

Then start training the net; after training finishes, save the model with code like this:

saver = tf.train.Saver(var_list=tf.global_variables())
savepath = saver.save(sess, 'here_is_your_personal_model_path')

Note that the var_list=tf.global_variables() param makes sure TensorFlow saves all the params, including the global mean/var, which are set as not trainable.
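As an optional sanity check, you can list what actually landed in the checkpoint and confirm that the non-trainable moving statistics were saved. This is just a sketch that reuses the placeholder path from above:

import tensorflow as tf

# print every variable stored in the checkpoint; alongside the trainable weights
# you should see entries such as '.../batch_normalization/moving_mean' and
# '.../batch_normalization/moving_variance'.
for name, shape in tf.train.list_variables('here_is_your_personal_model_path'):
    print(name, shape)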

When restoring and testing the model, do it like this:

# build the graph exactly as during training:
images, labels, _, _, prediction, accuracy, saver = self._build_graph(1, False)
saver = tf.train.Saver()
saver.restore(sess, 'here_is_your_personal_model_path')
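Putting it together, below is a sketch of the full restore-and-predict loop. It assumes _build_graph also exposes the is_training placeholder (here accessed as self.is_training, which is an assumption and not shown in the question):

tf.reset_default_graph()
images, labels, _, _, prediction, accuracy, saver = self._build_graph(1, False)
pred_op = tf.argmax(prediction, 1)  # build the argmax op once, outside the loop

predictions, correct = [], 0
with tf.Session() as sess:
    # no global initializer needed: every variable is restored from the checkpoint
    saver.restore(sess, 'here_is_your_personal_model_path')
    for i in range(100):
        pred, corr = sess.run(
            [pred_op, accuracy],
            feed_dict={images: [images_for_predict.images[i]],
                       labels: [images_for_predict.labels[i]],
                       self.is_training: False})  # assumed placeholder: use moving statistics
        correct += corr
        predictions.append(pred[0])
print("ACCURACY:", correct / 100)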

Now you can test your model. I hope this helps, thanks!

answered Oct 20 '22 by Yang


Looking at your implementation of batch norm: when you load your model, you need to keep the graph built with images, labels, _, _, prediction, accuracy, saver = self._build_graph(1, False) and load the weight values from the checkpoint, but NOT the meta graph. I think saver.restore(sess, './models/dense_nets_model/dense_nets.ckpt') also restores the meta graph now (sorry if I'm wrong), so you need to restore only the "data" part of it.

Otherwise, you're just using the graph built for training, in which the mean and variance used in batch norm are the ones computed from the batch. But at test time the batch has size 1, so normalizing by the batch's mean and variance always brings your data to 0, hence the constant output.
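To see this concretely, here is a tiny NumPy sketch (illustrative numbers only): with a batch of one, the per-feature mean equals the example itself and the variance is zero, so the normalized value collapses to beta for every input:

import numpy as np

x = np.array([[2.0, -1.5, 7.0]])   # one example, 3 features (batch size 1)
mean = x.mean(axis=0)              # equals x itself
var = x.var(axis=0)                # all zeros
eps, scale, beta = 1e-3, 1.0, 0.0
normalized = scale * (x - mean) / np.sqrt(var + eps) + beta
print(normalized)                  # [[0. 0. 0.]] -> the same output for every input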

In any case, I'd suggest using tf.layers.batch_normalization instead, with an is_training placeholder that you'll need to feed to your network...

answered Oct 20 '22 by gdelab