I'm training a linear regression model. I use tf.contrib.data to prepare the dataset, shuffle it and serve it in batches:
dataset = tf.contrib.data.TFRecordDataset(filename)
dataset = dataset.map(
    _parse_function, num_threads=16, output_buffer_size=100 * batch_size)
dataset = dataset.repeat(5)                    # 5 epochs
dataset = dataset.shuffle(buffer_size=100000)  # shuffle within a 100k-record buffer
dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [None]))
iterator = dataset.make_initializable_iterator()
x_inputs, y_ = iterator.get_next()
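For reference, tf.contrib.data has since been folded into tf.data; a roughly equivalent pipeline with the current API, reusing the same _parse_function, filename and batch_size, would look something like this:

dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(_parse_function, num_parallel_calls=16)
dataset = dataset.repeat(5)
dataset = dataset.shuffle(buffer_size=100000)
dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [None]))
dataset = dataset.prefetch(1)  # plays roughly the role of output_buffer_size
# TF 1.x style; on TF 2.x use tf.compat.v1.data.make_initializable_iterator(dataset)
iterator = dataset.make_initializable_iterator()
x_inputs, y_ = iterator.get_next()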
The training-loss curve shows a very strange pattern: at the beginning of each epoch (iteration = 100k) there is a pulse in the loss, and as training continues the same pattern appears at the beginning of every following epoch.
Assuming that your dataset has more than 100,000 records, the problem could be insufficient randomness in the input shuffle. Intuitively, if the data are not shuffled and have some structure in their ordering, the training process can overfit to the records toward the end of the file; when you then restart at the beginning of the file, the model will not perform as well on the records near the beginning, and the loss will increase.
The proper solution will depend on the precise details of your dataset, but some of the following approaches might be useful:
1. If possible, increase the buffer_size you pass to Dataset.shuffle() to be as large as the dataset. (This might not be possible if your entire dataset doesn't fit in memory.)

2. Ensure that the data does not have any structure in the ordering of records by randomizing the order of the records in the input file before training (a short offline-shuffle sketch follows this list).

3. In addition to randomizing the order within the file, you could partition the data into several different files and randomize the order in which you visit those files, using the following:

   dataset = tf.data.Dataset.list_files(shard_filenames)
   dataset = dataset.shuffle(len(shard_filenames))
   dataset = dataset.flat_map(lambda filename: tf.data.TFRecordDataset(filename))
   # ...

4. As an extension of (3), you could use Dataset.interleave() in place of Dataset.flat_map() to read multiple shards at once (see the sketch after this list).
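For (2), a minimal offline-shuffle sketch, assuming the whole file fits in memory and using placeholder file names ("train.tfrecord", "train_shuffled.tfrecord") with the TF 2.x API:

import random
import tensorflow as tf

# Read every serialized record, shuffle them in memory, and write a new file.
records = list(tf.data.TFRecordDataset("train.tfrecord").as_numpy_iterator())
random.shuffle(records)
with tf.io.TFRecordWriter("train_shuffled.tfrecord") as writer:
    for record in records:
        writer.write(record)

For (4), a minimal sketch with Dataset.interleave(), again assuming shard_filenames holds the list of shard file names; cycle_length controls how many shards are read concurrently:

dataset = tf.data.Dataset.list_files(shard_filenames)
dataset = dataset.shuffle(len(shard_filenames))
dataset = dataset.interleave(
    tf.data.TFRecordDataset,
    cycle_length=4,    # read 4 shards concurrently
    block_length=16)   # take 16 consecutive records from each shard per cycle
# ... continue with map/repeat/shuffle/padded_batch as in the question.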
The same sawtooth pattern is seen here: https://discuss.pytorch.org/t/strange-behavior-with-sgd-momentum-training/7442
The suggested solution there is to ensure that your data loader samples the data with replacement.
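One way to emulate with-replacement sampling in tf.data is sketched below; features and labels are hypothetical in-memory tensors (not part of the original pipeline), and each training example is chosen by an independent uniform index:

import tensorflow as tf

num_examples = tf.shape(features)[0]  # features/labels are hypothetical in-memory tensors

def draw_one(_):
    # Each index is drawn independently and uniformly, i.e. with replacement.
    idx = tf.random.uniform([], maxval=num_examples, dtype=tf.int32)
    return features[idx], labels[idx]

# Dataset.range only supplies a (large) element count; its values are ignored.
dataset = tf.data.Dataset.range(10**9).map(draw_one)
dataset = dataset.batch(batch_size)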