 

Early stopping in BERT Trainer instances

I am fine-tuning a BERT model for a multiclass classification task. My problem is that I don't know how to add "early stopping" to my Trainer instances. Any ideas?

asked Sep 07 '21 by soulwreckedyouth

People also ask

How does early stopping work?

These early stopping rules work by splitting the original training set into a new training set and a validation set. The error on the validation set is used as a proxy for the generalization error in determining when overfitting has begun. These methods are most commonly employed in the training of neural networks.
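A minimal sketch of that idea in plain Python, assuming hypothetical train_one_epoch() and validation_loss() helpers that stand in for your own training and evaluation code (no patience yet; that is covered below):

best_val_loss = float("inf")

for epoch in range(max_epochs):
    train_one_epoch(model, train_set)            # the new, smaller training set
    val_loss = validation_loss(model, val_set)   # held-out validation set
    if val_loss < best_val_loss:
        best_val_loss = val_loss                 # still improving: keep going
    else:
        break                                    # validation error rising: stop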

How does PyTorch apply early stopping in training?

In PyTorch Ignite, early stopping is configured through three main arguments: patience is the number of events to wait without improvement before stopping the training; score_function is a function that takes a single argument and returns a score float; trainer is the engine whose run is stopped when no improvement is seen.
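For example, a minimal sketch with Ignite's EarlyStopping handler, assuming you already have trainer and evaluator engines and an attached metric named "nll" (the metric name here is an assumption):

from ignite.engine import Events
from ignite.handlers import EarlyStopping

# The score function takes the engine and returns a float;
# EarlyStopping treats higher scores as better, so negate a loss
def score_function(engine):
    return -engine.state.metrics["nll"]  # "nll" is an assumed metric name

handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
# Attach to the evaluator so the score is checked after each validation run
evaluator.add_event_handler(Events.COMPLETED, handler)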

What is patience in early stopping?

Patience is an important parameter of the Early Stopping Callback. If the patience parameter is set to X epochs or iterations, then the training will terminate only if there is no improvement in the monitored performance measure for X epochs or iterations in a row.
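Sketched in plain Python, reusing the hypothetical helpers from above, the patience counter looks like this:

patience = 3                     # X epochs in a row without improvement
best_val_loss = float("inf")
epochs_without_improvement = 0

for epoch in range(max_epochs):
    train_one_epoch(model, train_set)
    val_loss = validation_loss(model, val_set)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0   # any improvement resets the counter
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break                        # patience exhausted: stop training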

Does PyTorch have early stopping?

Early stopping is a form of regularization used to avoid overfitting on the training dataset. Early stopping keeps track of the validation loss; if the loss stops decreasing for several epochs in a row, the training stops. Core PyTorch has no built-in early-stopping utility, so it is typically implemented by hand or provided by a library such as PyTorch Lightning.
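For example, a minimal sketch with PyTorch Lightning's EarlyStopping callback, assuming model is a LightningModule that logs a metric named "val_loss":

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

# Stop once "val_loss" has not decreased for 3 consecutive validation runs
early_stop = EarlyStopping(monitor="val_loss", mode="min", patience=3)
trainer = pl.Trainer(callbacks=[early_stop], max_epochs=20)
trainer.fit(model)  # model is your LightningModule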


1 Answer

There are a couple of modifications you need to perform before you can correctly use the EarlyStoppingCallback().

from transformers import (EarlyStoppingCallback, IntervalStrategy,
                          Trainer, TrainingArguments)
...
...
# Defining the TrainingArguments() arguments
# (assumes batch_size was defined earlier, e.g. batch_size = 16)
args = TrainingArguments(
    "training_with_callbacks",                   # output_dir
    evaluation_strategy=IntervalStrategy.STEPS,  # "steps"
    eval_steps=50,       # Evaluation happens every 50 steps
    save_steps=50,       # Save happens every 50 steps; must line up with eval_steps
    save_total_limit=5,  # Only the last 5 checkpoints are kept; older ones are deleted
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    push_to_hub=False,
    metric_for_best_model="f1",
    load_best_model_at_end=True,
)

You need to:

  1. Use load_best_model_at_end = True (EarlyStoppingCallback() requires this to be True).
  2. Use evaluation_strategy = 'steps' or IntervalStrategy.STEPS instead of 'epoch'.
  3. Set eval_steps = 50 (evaluate the metrics every N steps).
  4. Set metric_for_best_model = 'f1' (it must match one of the keys returned by your compute_metrics() function).

In your Trainer():

trainer = Trainer(
    model,
    args,
    ...
    compute_metrics=compute_metrics,
    # Stop training after 3 evaluations without improvement in the best metric
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

Of course, you also need to provide a compute_metrics() function; for example, it can look like this:

import numpy as np
from sklearn.metrics import (accuracy_score, f1_score,
                             precision_score, recall_score)

def compute_metrics(p):
    pred, labels = p
    pred = np.argmax(pred, axis=1)
    # average="macro" (an assumed choice) is required for multiclass targets;
    # sklearn's default average="binary" only works for binary classification
    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average="macro")
    precision = precision_score(y_true=labels, y_pred=pred, average="macro")
    f1 = f1_score(y_true=labels, y_pred=pred, average="macro")
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

compute_metrics() must return a dictionary; you can compute whatever metrics you want inside the function and return them. Note that the Trainer prefixes the logged keys with eval_, so metric_for_best_model = 'f1' is resolved to eval_f1 internally.

Note: In newer transformers versions, using the enum IntervalStrategy.STEPS is recommended (see TrainingArguments()) instead of the plain "steps" string, as the latter may be deprecated.

answered Oct 26 '22 by Timbus Calin