
PyTorch: is there a definitive training loop similar to Keras' fit()?

I'm coming over from Keras to PyTorch, and one of the surprising things I've found is that I'm supposed to implement my own training loop.

In Keras, there is a de facto fit() function that: (1) runs gradient descent and (2) collects a history of metrics for loss and accuracy over both the training set and validation set.

In PyTorch, it appears that the programmer needs to implement the training loop. Since I'm new to PyTorch, I don't know if my training loop implementation is correct. I just want to compare apples-to-apples loss and accuracy metrics with what I'm seeing in Keras.
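For reference, the part of fit() that does the gradient descent comes down to a few calls per batch. The sketch below uses a synthetic linear-regression problem; the data, the `nn.Linear` model, SGD and the learning rate are illustrative assumptions. It shows the usual order: clear gradients, forward pass, loss, backward pass, optimizer step.

```python
import torch
from torch import nn

# Synthetic regression data (illustration only): y = 3x + 1
torch.manual_seed(0)
x = torch.randn(64, 1)
y = 3 * x + 1

model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def train_step(model, optimizer, loss_fn, x, y):
    """One gradient-descent step, i.e. the per-batch work that fit() hides."""
    model.train()               # training-mode behaviour for dropout/batchnorm
    optimizer.zero_grad()       # clear gradients accumulated by the previous step
    loss = loss_fn(model(x), y)
    loss.backward()             # compute d(loss)/d(param) for every parameter
    optimizer.step()            # update the parameters using those gradients
    return loss.item()          # plain Python float, detached from the graph


losses = [train_step(model, optimizer, loss_fn, x, y) for _ in range(100)]
print(losses[0], losses[-1])
```

Here every "batch" is the full dataset; with a DataLoader the same step simply runs once per batch.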

I've already read through:

  1. the official PyTorch 60-minute blitz, where they provide a sample training loop.

  2. official PyTorch example code, where I've found the training loop placed in-line with other code.

  3. the O'Reilly book Programming PyTorch for Deep Learning with its own training loop.

  4. Stanford CS230 sample code.

  5. various blog posts (e.g. here and here).

So I'm wondering: is there a definitive, universal training loop implementation that does the same thing and reports the same numbers as the Keras fit() function?
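One detail that changes the reported numbers is how per-batch losses are combined into one epoch value. `nn.CrossEntropyLoss` defaults to `reduction='mean'`, so each call returns the mean over its batch. If the last batch is smaller, a plain mean of the batch means over-weights those samples; multiplying each batch mean by its batch size and dividing by the total count recovers the true per-sample mean. The sketch below checks this on synthetic logits and labels (an assumption for illustration). As far as I know, Keras' epoch-level training loss is also accumulated during the epoch while the weights are still changing, but treat that as something to verify against your Keras version.

```python
import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()  # reduction='mean' by default: average over the batch

# Synthetic logits/labels: 10 samples split into batches of 4, 4 and 2
torch.manual_seed(0)
logits = torch.randn(10, 3)
labels = torch.randint(0, 3, (10,))
batches = [(logits[i:i + 4], labels[i:i + 4]) for i in range(0, 10, 4)]

batch_means = []
weighted_sum, n_seen = 0.0, 0
for yhat, y in batches:
    loss = loss_fn(yhat, y)                   # mean loss of this batch
    batch_means.append(loss.item())
    weighted_sum += loss.item() * y.size(0)   # undo the mean: per-batch sum
    n_seen += y.size(0)

mean_of_means = sum(batch_means) / len(batch_means)  # over-weights the short last batch
per_sample_mean = weighted_sum / n_seen               # average over all samples
full = loss_fn(logits, labels).item()                 # loss over all 10 samples at once

print(mean_of_means, per_sample_mean, full)
```

The same reasoning applies to accuracy: count correct predictions and divide by the number of examples seen, rather than averaging per-batch accuracies.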

My points of frustration:

  1. Pulling data out of the dataloader is not consistent between image data and NLP data.

  2. Correctly computing loss and accuracy is not consistent in any sample code I've seen.

  3. Some code examples use Variable, while others do not.

  4. Too much low-level detail to manage by hand: moving data to/from the GPU, knowing when to call zero_grad().
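On points 3 and 4: since PyTorch 0.4, `torch.autograd.Variable` is deprecated and plain tensors work with autograd directly, so examples that wrap tensors in `Variable` are simply older code. `zero_grad()` is needed once per optimizer step because `backward()` accumulates into each parameter's `.grad` instead of overwriting it. On point 1, one way to hide the difference between tuple-style image batches and dict-style NLP batches is a small recursive helper. The name `to_device` and the hand-built NLP batch below are hypothetical, for illustration only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def to_device(batch, device):
    """Hypothetical helper: recursively move every tensor in `batch` to `device`.

    Handles a bare tensor, a list/tuple such as the [x, y] pair that
    DataLoader yields for a TensorDataset, or a dict of tensors as many NLP
    collate functions produce. Non-tensor leaves (e.g. raw strings) are
    returned unchanged; tuple subclasses such as namedtuples come back as
    plain tuples.
    """
    if torch.is_tensor(batch):
        return batch.to(device)
    if isinstance(batch, dict):
        return {k: to_device(v, device) for k, v in batch.items()}
    if isinstance(batch, (list, tuple)):
        moved = [to_device(v, device) for v in batch]
        return moved if isinstance(batch, list) else tuple(moved)
    return batch


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Image-style data: each batch is an [x, y] pair
ds = TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
for batch in DataLoader(ds, batch_size=4):
    x, y = to_device(batch, device)

# NLP-style data: each batch is a dict (built by hand here for illustration)
nlp_batch = {'input_ids': torch.ones(2, 5, dtype=torch.long), 'text': ['hi', 'there']}
nlp_batch = to_device(nlp_batch, device)
```

The training loop can then call `to_device(batch, device)` once and unpack whatever structure its own dataset produces.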

For what it's worth, here is my current implementation. Are there any obvious bugs?

import time

import torch

def train(model, optimizer, loss_fn, train_dl, val_dl, epochs=20, device='cuda'):
    """
    Runs training loop for classification problems. Returns Keras-style
    per-epoch history of loss and accuracy over training and validation data.

    Parameters
    ----------
    model : nn.Module
        Neural network model
    optimizer : torch.optim.Optimizer
        Search space optimizer (e.g. Adam)
    loss_fn :
        Loss function (e.g. nn.CrossEntropyLoss())
    train_dl :
        Iterable dataloader for training data.
    val_dl :
        Iterable dataloader for validation data.
    epochs : int
        Number of epochs to run
    device : string
        Specifies 'cuda' or 'cpu'

    Returns
    -------
    Dictionary
        Similar to Keras' fit(), the output dictionary contains per-epoch
        history of training loss, training accuracy, validation loss, and
        validation accuracy.
    """

    print('train() called: model=%s, opt=%s(lr=%f), epochs=%d, device=%s\n' %
          (type(model).__name__, type(optimizer).__name__,
           optimizer.param_groups[0]['lr'], epochs, device))

    model.to(device)  # parameters must live on the same device as the batches

    history = {}  # Collects per-epoch loss and acc like Keras' fit().
    history['loss'] = []
    history['val_loss'] = []
    history['acc'] = []
    history['val_acc'] = []

    start_time_sec = time.time()

    for epoch in range(epochs):

        # --- TRAIN AND EVALUATE ON TRAINING SET -----------------------------
        model.train()  # training-mode behaviour for dropout/batchnorm
        train_loss         = 0.0
        num_train_correct  = 0
        num_train_examples = 0

        for batch in train_dl:

            optimizer.zero_grad()  # gradients accumulate, so clear them every step

            x    = batch[0].to(device)
            y    = batch[1].to(device)
            yhat = model(x)
            loss = loss_fn(yhat, y)

            loss.backward()        # backpropagate
            optimizer.step()       # update the weights

            # loss is a batch mean; multiply by batch size to get a per-sample sum
            train_loss         += loss.item() * x.size(0)
            num_train_correct  += (torch.max(yhat, 1)[1] == y).sum().item()
            num_train_examples += x.shape[0]

        train_acc   = num_train_correct / num_train_examples
        train_loss  = train_loss / num_train_examples

        # --- EVALUATE ON VALIDATION SET -------------------------------------
        model.eval()  # inference-mode behaviour for dropout/batchnorm
        val_loss         = 0.0
        num_val_correct  = 0
        num_val_examples = 0

        with torch.no_grad():  # no gradients are needed for validation
            for batch in val_dl:

                x    = batch[0].to(device)
                y    = batch[1].to(device)
                yhat = model(x)
                loss = loss_fn(yhat, y)

                val_loss         += loss.item() * x.size(0)
                num_val_correct  += (torch.max(yhat, 1)[1] == y).sum().item()
                num_val_examples += y.shape[0]

        val_acc  = num_val_correct / num_val_examples
        val_loss = val_loss / num_val_examples

        print('Epoch %3d/%3d, train loss: %5.2f, train acc: %5.2f, val loss: %5.2f, val acc: %5.2f' %
              (epoch+1, epochs, train_loss, train_acc, val_loss, val_acc))

        history['loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['acc'].append(train_acc)
        history['val_acc'].append(val_acc)

    end_time_sec       = time.time()
    total_time_sec     = end_time_sec - start_time_sec
    time_per_epoch_sec = total_time_sec / epochs
    print('Time total:     %5.2f sec' % (total_time_sec))
    print('Time per epoch: %5.2f sec' % (time_per_epoch_sec))

    return history
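A validation pass is normally run with the model in eval mode and inside `torch.no_grad()`, so that dropout and batchnorm behave deterministically and no autograd graph is built. Below is a standalone sketch of such a pass. The function name `evaluate`, the toy model and the random data are illustrative assumptions; it also assumes each batch is an `(x, y)` pair and that the loss function returns a batch mean.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no autograd graph is recorded inside this function
def evaluate(model, loss_fn, dl, device='cpu'):
    """Hypothetical helper: return (mean per-sample loss, accuracy) over `dl`."""
    model.eval()  # dropout disabled, batchnorm uses its running statistics
    total_loss, n_correct, n_seen = 0.0, 0, 0
    for x, y in dl:
        x, y = x.to(device), y.to(device)
        yhat = model(x)
        total_loss += loss_fn(yhat, y).item() * y.size(0)  # batch mean -> batch sum
        n_correct += (yhat.argmax(dim=1) == y).sum().item()
        n_seen += y.size(0)
    return total_loss / n_seen, n_correct / n_seen


# Toy classifier with dropout, so eval mode actually matters
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 3))
val_dl = DataLoader(TensorDataset(torch.randn(10, 4), torch.randint(0, 3, (10,))), batch_size=4)

first = evaluate(model, nn.CrossEntropyLoss(), val_dl)
second = evaluate(model, nn.CrossEntropyLoss(), val_dl)
print(first, second)  # identical, because dropout is off in eval mode
```

Because `evaluate` leaves the model in eval mode, the training loop has to call `model.train()` again at the start of the next epoch.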
stackoverflowuser2010 asked Jan 03 '20 19:01


2 Answers

Indeed, the PyTorch Module class (source code) doesn't have a fit() method, so you have to implement your own according to your needs. However, there are some implementations that mimic the Keras training API, such as this one:


or a simpler one:


Andreas K. answered Oct 17 '22 02:10

Short answer: there is no training loop that is equivalent between PyTorch (PT) and TF.keras, and there never will be.

First of all, the training loop is syntactic sugar that is supposed to make one's life easier. From my point of view, "making life easier" is the motto of the TF.keras framework, and this is the main reason it has one. A training loop cannot be formalized as a well-defined practice: it can vary a lot depending on the task, dataset, procedure, metric, you name it, and matching all the options across the two frameworks would require a lot of effort. Furthermore, defining a fixed interface for the training loop in PyTorch might be too restrictive for many actual users of the framework.

Matching the outputs of a network would require matching the behavior of every operation across the two frameworks, which would be impossible. To begin with, the frameworks don't necessarily provide the same sets of operations, and operations can be grouped into higher-level abstractions differently. Also, some common functions like sigmoid or BatchNorm may look mathematically well defined on paper, but in reality have dozens of implementation-specific details. Finally, when improvements to an operation are introduced, it is up to the community to integrate them into the main framework distribution or simply ignore them. Needless to say, the developers of the two frameworks make these decisions independently and likely with different motivations.
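BatchNorm's running statistics are a concrete example of such details. PyTorch's `nn.BatchNorm1d` defaults to `momentum=0.1` and `eps=1e-5` and updates its running mean as `(1 - momentum) * running + momentum * batch_mean`, while Keras' `BatchNormalization` documents `moving = momentum * moving + (1 - momentum) * batch_mean` with defaults `momentum=0.99` and `epsilon=1e-3`. The sketch below uses synthetic data and writes the Keras update out by hand rather than calling Keras, so the Keras side is an illustration of its documented formula, not a run of the library.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(32, 4) + 5.0   # synthetic batch whose feature means are near 5

bn = nn.BatchNorm1d(4)          # PyTorch defaults: momentum=0.1, eps=1e-5
bn.train()
bn(x)                           # one training-mode forward pass updates running stats

batch_mean = x.mean(dim=0)

# PyTorch convention (running_mean starts at zeros):
# running = (1 - momentum) * running + momentum * batch_mean
expected_pt = (1 - bn.momentum) * torch.zeros(4) + bn.momentum * batch_mean

# Keras convention, written out by hand (moving_mean also starts at zeros):
# moving = momentum * moving + (1 - momentum) * batch_mean, default momentum=0.99
keras_momentum = 0.99
keras_style = keras_momentum * torch.zeros(4) + (1 - keras_momentum) * batch_mean

print(bn.running_mean)
print(keras_style)
```

The same word "momentum" plays opposite roles: for the running-mean update, Keras' default of 0.99 corresponds to `momentum=0.01` in PyTorch, and the epsilon defaults differ as well, so identical architectures can drift apart numerically.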

To sum it all up, matching the high-level details of the two frameworks would require an enormous effort and would probably be very disruptive for existing users.

y.selivonchyk answered Oct 17 '22 03:10
