 

Why does the training loss increase at the beginning of each epoch?

I'm training a linear regression model. I use tf.contrib.data to prepare the dataset, shuffle it, and serve it in batches:

  dataset = tf.contrib.data.TFRecordDataset(filename)
  dataset = dataset.map(
      _parse_function, num_threads=16, output_buffer_size=100 * batch_size)
  dataset = dataset.repeat(5)
  dataset = dataset.shuffle(buffer_size=100000)
  dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [None]))
  iterator = dataset.make_initializable_iterator()
  x_inputs, y_ = iterator.get_next()

The following plot shows our training loss:

[plot: training loss vs. iteration]

Strangely, at the beginning of each epoch (every 100k iterations) there is a spike in the training loss. As training continues, the same pattern repeats at the start of every subsequent epoch.

asked Oct 25 '17 by DNN SEARCH


2 Answers

Assuming that your dataset has more than 100000 records, the problem could be insufficient randomness in the input shuffle: the shuffle buffer holds only 100000 examples, so records are reordered only locally. Intuitively, if the data on disk are not shuffled and have some structure in their ordering, the model can overfit to the records towards the end of the file; when the pipeline restarts at the beginning of the file, the model will not perform as well on the records near the start, and the loss will increase.

The proper solution will depend on the precise details of your dataset, but some of the following approaches might be useful:

  1. If possible, increase the buffer_size you pass to Dataset.shuffle() to be as large as the dataset. (This might not be possible if your entire dataset doesn't fit in memory.)

  2. Ensure that the data does not have any structure in the ordering of records by randomizing the order of the input file before training.

  3. In addition to randomizing the order within the file, you could partition the data into several different files, and randomize the order in which you visit those files using the following:

    dataset = tf.data.Dataset.list_files(shard_filenames)
    dataset = dataset.shuffle(len(shard_filenames))
    dataset = dataset.flat_map(lambda filename: tf.data.TFRecordDataset(filename))
    # ...
    
  4. As an extension of (3), you could use Dataset.interleave() in place of Dataset.flat_map() to read from multiple shards at once (see the sketch after this list).
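
To make option (4) concrete, here is a minimal sketch of the interleaved variant. The cycle_length and block_length values are only illustrative, and shard_filenames is assumed to be the same list of shard files as in the snippet above:

  dataset = tf.data.Dataset.list_files(shard_filenames)
  dataset = dataset.shuffle(len(shard_filenames))
  # Read from `cycle_length` shards at a time, taking `block_length`
  # consecutive records from each, so that neighbouring records in the
  # stream come from different files.
  dataset = dataset.interleave(
      lambda filename: tf.data.TFRecordDataset(filename),
      cycle_length=4, block_length=16)
  # ... then map/shuffle/batch as before.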

answered Nov 10 '22 by mrry


The same sawtooth pattern is seen here: https://discuss.pytorch.org/t/strange-behavior-with-sgd-momentum-training/7442

The suggested solution there is to ensure that your data loader samples the data with replacement, as in the sketch below.
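
A rough sketch of what sampling with replacement looks like with a PyTorch DataLoader. The TensorDataset here is a made-up stand-in purely for illustration; the relevant part is RandomSampler(..., replacement=True):

  import torch
  from torch.utils.data import DataLoader, RandomSampler, TensorDataset

  # Stand-in dataset purely for illustration.
  features = torch.randn(1000, 10)
  targets = torch.randn(1000, 1)
  dataset = TensorDataset(features, targets)

  # Drawing indices with replacement means each batch is an i.i.d. sample
  # from the dataset rather than a pass over one fixed ordering, which is
  # what the linked thread suggests to remove the epoch-boundary spike.
  sampler = RandomSampler(dataset, replacement=True, num_samples=len(dataset))
  loader = DataLoader(dataset, batch_size=32, sampler=sampler)

  for x_batch, y_batch in loader:
      pass  # one training step per batch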

answered Nov 10 '22 by Kevinj22