
How to implement early stopping in TensorFlow

def train():
    # Model
    model = Model()

    # Loss, Optimizer
    global_step = tf.Variable(1, dtype=tf.int32, trainable=False, name='global_step')
    loss_fn = model.loss()
    optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(loss_fn, global_step=global_step)

    # Summaries
    summary_op = summaries(model, loss_fn)

    with tf.Session(config=TrainConfig.session_conf) as sess:

        # Initialize variables, load saved state
        sess.run(tf.global_variables_initializer())
        model.load_state(sess, TrainConfig.CKPT_PATH)

        writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph)

        # Input source
        data = Data(TrainConfig.DATA_PATH)

        loss = Diff()
        for step in range(global_step.eval(), TrainConfig.FINAL_STEP):  # xrange in the original (Python 2)

            mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(TrainConfig.SECONDS, TrainConfig.NUM_WAVFILE, step)

            mixed_spec = to_spectrogram(mixed_wav)
            mixed_mag = get_magnitude(mixed_spec)

            src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(src2_wav)
            src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(src2_spec)

            src1_batch, _ = model.spec_to_batch(src1_mag)
            src2_batch, _ = model.spec_to_batch(src2_mag)
            mixed_batch, _ = model.spec_to_batch(mixed_mag)

            # Initialize our callback.
            # early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.5)

            l, _, summary = sess.run([loss_fn, optimizer, summary_op],
                                     feed_dict={model.x_mixed: mixed_batch, model.y_src1: src1_batch,
                                                model.y_src2: src2_batch})

            loss.update(l)
            print('step-{}\td_loss={:2.2f}\tloss={}'.format(step, loss.diff * 100, loss.value))

            writer.add_summary(summary, global_step=step)

            # Save state
            if step % TrainConfig.CKPT_STEP == 0:
                tf.train.Saver().save(sess, TrainConfig.CKPT_PATH + '/checkpoint', global_step=step)

        writer.close()

I have this neural network code that separates music from voice in a .wav file. How can I introduce an early stopping algorithm to stop the train section? I have seen some projects that talk about a ValidationMonitor. Can someone help me?

asked Sep 26 '17 by Valerio

People also ask

How does early stopping work in Keras?

Early stopping is a method that allows you to specify an arbitrarily large number of training epochs and stop training once the model performance stops improving on the validation dataset.

When should I stop TensorFlow training?

Training will stop if the model doesn't show improvement over the baseline. A separate restore_best_weights argument controls whether to restore model weights from the epoch with the best value of the monitored quantity; if False, the model weights obtained at the last step of training are used.
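As an illustration (not part of the original question or answer), a minimal tf.keras sketch of the built-in EarlyStopping callback might look like this; the model and data below are placeholders:

import numpy as np
import tensorflow as tf

# Placeholder data and model, just to make the sketch self-contained.
x = np.random.rand(1000, 20).astype('float32')
y = np.random.randint(0, 2, size=(1000, 1))

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(20,)),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Stop once val_loss has not improved for 20 epochs, and roll back to the best weights.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',         # quantity watched for improvement
    patience=20,                # epochs without improvement before stopping
    restore_best_weights=True,  # restore the best epoch instead of keeping the last one
)

# Epochs can be set arbitrarily large; the callback ends training early.
model.fit(x, y, validation_split=0.2, epochs=10000, callbacks=[early_stop])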


1 Answer

Here is my implementation of early stopping; you can adapt it:

Early stopping can be applied at certain stages of the training process, such as at the end of each epoch. Specifically, in my case, I monitor the test (validation) loss at each epoch, and once the test loss has not improved for 20 epochs (self.require_improvement = 20), training is interrupted.

You can set the max epochs to 10000 or 20000 or whatever you want (self.max_epochs = 10000).

    self.require_improvement = 20
    self.max_epochs = 10000

Here is my training function where I use early stopping:

import random      # module-level imports required by this method
import numpy as np

def train(self):
    # training data
    train_input = self.Normalize(self.x_train)
    train_output = self.y_train.copy()

    save_sess = self.sess  # used to compare the result of the previous session with the current one

    # costs history:
    costs = []
    costs_inter = []

    # for early stopping:
    best_cost = 1000000
    stop = False
    last_improvement = 0

    n_samples = train_input.shape[0]  # size of the training set

    # train the mini-batch model using the early stopping criteria
    epoch = 0
    while epoch < self.max_epochs and not stop:
        # train the model on the training set by mini-batches:
        # shuffle, then split the training set into mini-batches of size self.batch_size
        seq = list(range(n_samples))
        random.shuffle(seq)
        mini_batches = [
            seq[k:k + self.batch_size]
            for k in range(0, n_samples, self.batch_size)
        ]

        avg_cost = 0.  # the average cost over the mini-batches
        step = 0

        for sample in mini_batches:
            batch_x = train_input.iloc[sample, :]  # was x_train in the original; train_input is the normalized copy
            batch_y = train_output.iloc[sample, :]
            batch_y = np.array(batch_y).flatten()

            feed_dict = {self.X: batch_x, self.Y: batch_y, self.is_train: True}

            _, cost, acc = self.sess.run([self.train_step, self.loss_, self.accuracy_], feed_dict=feed_dict)
            avg_cost += cost * len(sample) / n_samples

            print('epoch[{}] step [{}] train -- loss : {}, accuracy : {}'.format(epoch, step, avg_cost, acc))
            step += 100

        # cost history since the last best cost
        costs_inter.append(avg_cost)

        # early stopping based on the validation set / max steps without a decrease
        # of the loss value: require_improvement
        if avg_cost < best_cost:
            save_sess = self.sess  # save session
            best_cost = avg_cost
            costs += costs_inter  # costs history of the validation set
            last_improvement = 0
            costs_inter = []
        else:
            last_improvement += 1
        if last_improvement > self.require_improvement:
            print("No improvement found during the last {} iterations, stopping optimization.".format(self.require_improvement))
            # Break out from the loop.
            stop = True
            self.sess = save_sess  # restore the session with the best cost

        # run validation after every epoch
        print('---------------------------------------------------------')
        self.y_validation = np.array(self.y_validation).flatten()
        loss_valid, acc_valid = self.sess.run([self.loss_, self.accuracy_],
                                              feed_dict={self.X: self.x_validation, self.Y: self.y_validation,
                                                         self.is_train: True})  # note: is_train=True mirrors the original; arguably it should be False when evaluating
        print("Epoch: {0}, validation loss: {1:.2f}, validation accuracy: {2:.01%}".format(epoch + 1, loss_valid, acc_valid))
        print('---------------------------------------------------------')

        epoch += 1

To summarize, the important code is:

def train(self):
    ...
    # costs history:
    costs = []
    costs_inter = []

    # for early stopping:
    best_cost = 1000000
    stop = False
    last_improvement = 0

    # train the mini-batch model using the early stopping criteria
    epoch = 0
    while epoch < self.max_epochs and not stop:
        ...
        for sample in mini_batches:
            ...

        # cost history since the last best cost
        costs_inter.append(avg_cost)

        # early stopping based on the validation set / max steps without a decrease
        # of the loss value: require_improvement
        if avg_cost < best_cost:
            save_sess = self.sess  # save session
            best_cost = avg_cost
            costs += costs_inter  # costs history of the validation set
            last_improvement = 0
            costs_inter = []
        else:
            last_improvement += 1
        if last_improvement > self.require_improvement:
            print("No improvement found during the last {} iterations, stopping optimization.".format(self.require_improvement))
            # Break out from the loop.
            stop = True
            self.sess = save_sess  # restore the session with the best cost
        ...
        epoch += 1
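For the step-based loop in the question, the same pattern could be grafted on roughly like this. This is only a sketch, not the answer's code: loss_fn, optimizer, model, sess, and TrainConfig come from the question, while compute_validation_loss is a hypothetical helper you would have to implement (evaluating loss_fn on a held-out set):

# Sketch: early stopping grafted onto the question's step-based TF1 loop,
# intended to replace the loop body inside the tf.Session block.
best_cost = float('inf')
last_improvement = 0
require_improvement = 20  # validation checks without improvement before stopping
saver = tf.train.Saver()

for step in range(global_step.eval(), TrainConfig.FINAL_STEP):
    # ... existing training code: build the batches, run the optimizer ...

    if step % TrainConfig.CKPT_STEP == 0:
        val_loss = compute_validation_loss(sess, model, data)  # hypothetical helper
        if val_loss < best_cost:
            best_cost = val_loss
            last_improvement = 0
            saver.save(sess, TrainConfig.CKPT_PATH + '/best')  # keep the best weights
        else:
            last_improvement += 1
        if last_improvement > require_improvement:
            print('No improvement for {} validation checks, stopping.'.format(require_improvement))
            saver.restore(sess, TrainConfig.CKPT_PATH + '/best')  # roll back to the best weights
            break

Saving a checkpoint at the best validation loss and restoring it on stop plays the same role as save_sess/self.sess in the answer above.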

Hope this helps someone :).

answered Sep 30 '22 by DINA TAKLIT