Training a tf.keras model with a basic low-level TensorFlow training loop doesn't work

Note: All code for a self-contained example to reproduce my problem can be found below.

I have a tf.keras.models.Model instance and need to train it with a training loop written in the low-level TensorFlow API.

The problem: Training the exact same tf.keras model once with a basic, standard low-level TensorFlow training loop and once with Keras' own model.fit() method produces very different results. I would like to find out what I'm doing wrong in my low-level TF training loop.

The model is a simple image classification model that I train on Caltech256 (link to tfrecords below).

With the low-level TensorFlow training loop, the training loss first decreases as it should, but after just 1000 training steps it plateaus and then starts increasing again.

Training the same model on the same dataset with the normal Keras training loop, on the other hand, works as expected.

What am I missing in my low-level TensorFlow training loop?

Here is the code to reproduce the problem (download the TFRecords with the link at the bottom):

import tensorflow as tf
from tqdm import trange
import sys
import glob
import os

sess = tf.Session()

num_classes = 257
image_size = (224, 224, 3)

# Build a tf.data.Dataset from TFRecords.

tfrecord_directory = 'path/to/tfrecords/directory'

tfrecord_filenames = glob.glob(os.path.join(tfrecord_directory, '*.tfrecord'))

feature_schema = {'image': tf.FixedLenFeature([], tf.string),
                  'filename': tf.FixedLenFeature([], tf.string),
                  'label': tf.FixedLenFeature([], tf.int64)}

dataset = tf.data.Dataset.from_tensor_slices(tfrecord_filenames)
dataset = dataset.shuffle(len(tfrecord_filenames)) # Shuffle the TFRecord file names.
dataset = dataset.flat_map(lambda filename: tf.data.TFRecordDataset(filename))
dataset = dataset.map(lambda single_example_proto: tf.parse_single_example(single_example_proto, feature_schema)) # Deserialize tf.Example objects.
dataset = dataset.map(lambda sample: (sample['image'], sample['label']))
dataset = dataset.map(lambda image, label: (tf.image.decode_jpeg(image, channels=3), label)) # Decode JPEG images.
dataset = dataset.map(lambda image, label: (tf.image.resize_image_with_pad(image, target_height=image_size[0], target_width=image_size[1]), label))
dataset = dataset.map(lambda image, label: (tf.image.per_image_standardization(image), label))
dataset = dataset.map(lambda image, label: (image, tf.one_hot(indices=label, depth=num_classes))) # Convert labels to one-hot format.
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.repeat()
dataset = dataset.batch(32)

iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()

# Build a simple model.

input_tensor = tf.keras.layers.Input(shape=image_size)
x = tf.keras.layers.Conv2D(64, (3,3), strides=(2,2), activation='relu', kernel_initializer='he_normal')(input_tensor)
x = tf.keras.layers.Conv2D(64, (3,3), strides=(2,2), activation='relu', kernel_initializer='he_normal')(x)
x = tf.keras.layers.Conv2D(128, (3,3), strides=(2,2), activation='relu', kernel_initializer='he_normal')(x)
x = tf.keras.layers.Conv2D(256, (3,3), strides=(2,2), activation='relu', kernel_initializer='he_normal')(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(num_classes, activation=None, kernel_initializer='he_normal')(x)
model = tf.keras.models.Model(input_tensor, x)
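To make the architecture concrete, here is a small plain-Python sketch that traces the feature-map shapes and parameter counts layer by layer. It assumes Keras' default padding='valid' for the Conv2D layers (the code above does not set padding), and the helper names are hypothetical.

```python
def conv_out(size, kernel=3, stride=2):
    """Output length of a 'valid' convolution along one spatial axis."""
    return (size - kernel) // stride + 1

def conv_params(in_ch, out_ch, kernel=3):
    """Weights (kernel * kernel * in_ch * out_ch) plus one bias per filter."""
    return kernel * kernel * in_ch * out_ch + out_ch

def trace_model(image_size=(224, 224, 3), filters=(64, 64, 128, 256), num_classes=257):
    """Return the shape after each Conv2D and the total trainable parameter count."""
    h, w, ch = image_size
    shapes, total = [], 0
    for f in filters:
        total += conv_params(ch, f)
        h, w, ch = conv_out(h), conv_out(w), f
        shapes.append((h, w, ch))
    # GlobalAveragePooling2D has no parameters and reduces (h, w, ch) to (ch,).
    total += ch * num_classes + num_classes  # final Dense layer: weights + biases
    return shapes, total

if __name__ == '__main__':
    shapes, total = trace_model()
    for s in shapes:
        print(s)
    print('trainable parameters:', total)
```

Under that assumption the spatial size shrinks 224 -> 111 -> 55 -> 27 -> 13, so the last convolution outputs 13 x 13 x 256 before pooling, and the model has 473,793 trainable parameters in total.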

This is the simple TensorFlow training loop:

# Build the training-relevant part of the graph.

model_output = model(features)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.stop_gradient(labels), logits=model_output))

train_op = tf.train.AdamOptimizer().minimize(loss)

# The next block is for the metrics.
with tf.variable_scope('metrics') as scope:
    predictions_argmax = tf.argmax(model_output, axis=-1, output_type=tf.int64)
    labels_argmax = tf.argmax(labels, axis=-1, output_type=tf.int64)
    mean_loss_value, mean_loss_update_op = tf.metrics.mean(loss)
    acc_value, acc_update_op = tf.metrics.accuracy(labels=labels_argmax, predictions=predictions_argmax)
    local_metric_vars = tf.contrib.framework.get_variables(scope=scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
    metrics_reset_op = tf.variables_initializer(var_list=local_metric_vars)

# Run the training

epochs = 3
steps_per_epoch = 1000

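# ret[0] and ret[1] are the running loss and accuracy; the remaining ops
# perform the weight update and update the streaming metrics.
fetch_list = [mean_loss_value,
              acc_value,
              train_op,
              mean_loss_update_op,
              acc_update_op]

# Initialize the model and optimizer variables.
sess.run(tf.global_variables_initializer())

with sess.as_default():

    for epoch in range(1, epochs+1):

        tr = trange(steps_per_epoch, file=sys.stdout)
        tr.set_description('Epoch {}/{}'.format(epoch, epochs))

        # Reset (and initialize) the streaming metrics at the start of every epoch.
        sess.run(metrics_reset_op)
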
        for train_step in tr:

            ret = sess.run(fetch_list, feed_dict={tf.keras.backend.learning_phase(): 1})

            tr.set_postfix(ordered_dict={'loss': ret[0],
                                         'accuracy': ret[1]})
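For reference, the per-example quantity that tf.nn.softmax_cross_entropy_with_logits_v2 is documented to compute is -sum_i y_i * log(softmax(z)_i), which the loop above then averages over the batch with tf.reduce_mean. The plain-Python sketch below computes that quantity; the helper names are hypothetical, and it uses the usual log-sum-exp trick for numerical stability rather than TensorFlow's actual kernel.

```python
import math

def log_sum_exp(logits):
    """log(sum(exp(z))) computed stably by factoring out the maximum logit."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits))

def softmax_cross_entropy_from_logits(labels, logits):
    """-sum_i labels[i] * log_softmax(logits)[i] for a single example."""
    lse = log_sum_exp(logits)
    return -sum(y * (z - lse) for y, z in zip(labels, logits))

def batch_loss(batch_labels, batch_logits):
    """Mean over the batch, mirroring tf.reduce_mean around the per-example loss."""
    losses = [softmax_cross_entropy_from_logits(y, z)
              for y, z in zip(batch_labels, batch_logits)]
    return sum(losses) / len(losses)
```

Shifting all logits of an example by the same constant leaves this loss unchanged, which is why subtracting the maximum logit is safe and keeps large logits such as 1000.0 from overflowing exp().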

Below is the standard Keras training loop, which works as expected. Note that the activation of the dense layer in the model above needs to be changed from None to 'softmax' in order for the Keras loop to work.

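epochs = 3
steps_per_epoch = 1000

# Assumed: compiled with the same optimizer as the low-level loop.
model.compile(optimizer=tf.train.AdamOptimizer(),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(dataset, epochs=epochs, steps_per_epoch=steps_per_epoch)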
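With the 'softmax' output, the Keras loop's loss receives probabilities instead of logits. As far as I understand the TF 1.x Keras backend, categorical_crossentropy with from_logits=False renormalizes the probabilities, clips them to [epsilon, 1 - epsilon] (epsilon defaults to 1e-7) and returns -sum(y * log(p)). The plain-Python sketch below mimics that probability path under this assumption and compares it with the unclipped logits-based loss; the helper names and the EPSILON constant are hypothetical stand-ins.

```python
import math

EPSILON = 1e-7  # assumed to match the default of tf.keras.backend.epsilon()

def softmax(logits):
    """Numerically stable softmax of one example's logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def crossentropy_from_probs(labels, probs, eps=EPSILON):
    """Keras-style probability path: renormalize, clip, then -sum(y * log(p))."""
    total = sum(probs)
    clipped = [min(max(p / total, eps), 1.0 - eps) for p in probs]
    return -sum(y * math.log(p) for y, p in zip(labels, clipped))

def crossentropy_from_logits(labels, logits):
    """Unclipped softmax cross-entropy computed directly from logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(y * (z - lse) for y, z in zip(labels, logits))

if __name__ == '__main__':
    labels = [1.0, 0.0, 0.0]
    for logits in ([2.0, 1.0, 0.1], [-30.0, 0.0, 0.0]):
        print(crossentropy_from_probs(labels, softmax(logits)),
              crossentropy_from_logits(labels, logits))
```

For moderate logits such as [2.0, 1.0, 0.1] the two paths agree to floating-point precision. For [-30.0, 0.0, 0.0] the true-class probability falls below epsilon, so the probability path caps the loss at -log(1e-7), about 16.12, while the logits path reports about 30.69.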

You can download the TFRecords for the Caltech256 dataset here (about 850 MB).


I've managed to solve the problem: Replacing the low-level TF loss function

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.stop_gradient(labels), logits=model_output))

by its Keras equivalent

loss = tf.reduce_mean(tf.keras.backend.categorical_crossentropy(target=labels, output=model_output, from_logits=True))

does the trick. Now the low-level TensorFlow training loop behaves just like model.fit().

This raises a new question:

What does tf.keras.backend.categorical_crossentropy() do differently from tf.nn.softmax_cross_entropy_with_logits_v2() that makes the latter perform so much worse? (I know that the latter expects logits rather than softmax output, so that's not the issue.)

Answer

Replacing tf.nn.softmax_cross_entropy_with_logits_v2() with tf.keras.backend.categorical_crossentropy(..., from_logits=True), exactly as shown in the update to the question, does the trick: the low-level TensorFlow training loop now behaves just like model.fit().

However, I don't know why this is. If anyone knows why tf.keras.backend.categorical_crossentropy() behaves well while tf.nn.softmax_cross_entropy_with_logits_v2() doesn't work at all, please post an answer.

Another important note:

In order to train a tf.keras model with a low-level TF training loop and a tf.data.Dataset object, one generally shouldn't call the model on the iterator output. That is, one shouldn't do this:

model_output = model(features)

Instead, one should build the model so that its Input layer wraps the iterator output rather than creating a placeholder, like so:

input_tensor = tf.keras.layers.Input(tensor=features)

This doesn't matter in this example, but it becomes relevant if any layers in the model have internal updates that need to be run during training (e.g. BatchNormalization).

