I'm coming over from Keras to PyTorch, and one of the surprising things I've found is that I'm supposed to implement my own training loop.
In Keras, there is a de facto fit() function that: (1) runs gradient descent and (2) collects a history of metrics for loss and accuracy over both the training set and validation set.
In PyTorch, it appears that the programmer needs to implement the training loop. Since I'm new to PyTorch, I don't know if my training loop implementation is correct. I just want to compare apples-to-apples loss and accuracy metrics with what I'm seeing in Keras.
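To be concrete about what I'm trying to reproduce, this is roughly the Keras behaviour I have in mind (toy data purely for illustration; the exact history key names depend on the Keras version):

import numpy as np
from tensorflow import keras

x = np.random.rand(64, 20).astype('float32')
y = np.random.randint(0, 2, size=(64,))

model = keras.Sequential([keras.layers.Dense(2, activation='softmax')])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

hist = model.fit(x, y, validation_split=0.25, epochs=3, verbose=0)
# hist.history holds per-epoch lists, e.g. 'loss', 'accuracy', 'val_loss', 'val_accuracy'
print(hist.history.keys())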
I've already read through:
the official PyTorch 60-minute blitz, where they provide a sample training loop.
official PyTorch example code, where I've found the training loop placed in-line with other code.
the O'Reilly book Programming PyTorch for Deep Learning with its own training loop.
Stanford CS230 sample code.
various blog posts (e.g. here and here).
So I'm wondering: is there a definitive, universal training loop implementation that does the same thing and reports the same numbers as the Keras fit() function?
My points of frustration:
Pulling data out of the dataloader is not consistent between image data and NLP data (see the sketch after this list).
Correctly computing loss and accuracy is not consistent in any sample code I've seen.
Some code examples use Variable, while others do not.
Unnecessary low-level detail: moving data to and from the GPU; knowing when to call zero_grad().
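To make the first point concrete, here is a minimal sketch of the two batch formats I keep running into (dummy tensors, and the dict keys are just hypothetical names):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Tuple-style batches, as most torchvision datasets produce.
image_ds = TensorDataset(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))
image_dl = DataLoader(image_ds, batch_size=4)
for x, y in image_dl:              # positional unpacking
    pass

# Dict-style batches, as many NLP pipelines produce (hypothetical collate_fn).
def collate(batch):
    xs, ys = zip(*batch)
    return {'input_ids': torch.stack(xs), 'labels': torch.stack(ys)}

text_ds = TensorDataset(torch.randint(0, 100, (8, 16)), torch.randint(0, 2, (8,)))
text_dl = DataLoader(text_ds, batch_size=4, collate_fn=collate)
for batch in text_dl:
    x, y = batch['input_ids'], batch['labels']   # field access instead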
For what it's worth, here is my current implementation. Are there any obvious bugs?
import time

import torch


def train(model, optimizer, loss_fn, train_dl, val_dl, epochs=20, device='cuda'):
    '''
    Runs a training loop for classification problems. Returns a Keras-style
    per-epoch history of loss and accuracy over the training and validation data.

    Parameters
    ----------
    model : nn.Module
        Neural network model
    optimizer : torch.optim.Optimizer
        Search space optimizer (e.g. Adam)
    loss_fn :
        Loss function (e.g. nn.CrossEntropyLoss())
    train_dl :
        Iterable dataloader for training data.
    val_dl :
        Iterable dataloader for validation data.
    epochs : int
        Number of epochs to run
    device : string
        Specifies 'cuda' or 'cpu'

    Returns
    -------
    Dictionary
        Similar to Keras' fit(), the output dictionary contains per-epoch
        history of training loss, training accuracy, validation loss, and
        validation accuracy.
    '''

    print('train() called: model=%s, opt=%s(lr=%f), epochs=%d, device=%s\n' %
          (type(model).__name__, type(optimizer).__name__,
           optimizer.param_groups[0]['lr'], epochs, device))

    history = {}  # Collects per-epoch loss and acc like Keras' fit().
    history['loss'] = []
    history['val_loss'] = []
    history['acc'] = []
    history['val_acc'] = []

    start_time_sec = time.time()

    for epoch in range(epochs):

        # --- TRAIN AND EVALUATE ON TRAINING SET -----------------------------
        model.train()
        train_loss = 0.0
        num_train_correct = 0
        num_train_examples = 0

        for batch in train_dl:
            optimizer.zero_grad()

            x = batch[0].to(device)
            y = batch[1].to(device)
            yhat = model(x)
            loss = loss_fn(yhat, y)

            loss.backward()
            optimizer.step()

            train_loss += loss.item() * x.size(0)
            num_train_correct += (torch.max(yhat, 1)[1] == y).sum().item()
            num_train_examples += x.shape[0]

        train_acc = num_train_correct / num_train_examples
        train_loss = train_loss / len(train_dl.dataset)

        # --- EVALUATE ON VALIDATION SET -------------------------------------
        model.eval()
        val_loss = 0.0
        num_val_correct = 0
        num_val_examples = 0

        for batch in val_dl:
            x = batch[0].to(device)
            y = batch[1].to(device)
            yhat = model(x)
            loss = loss_fn(yhat, y)

            val_loss += loss.item() * x.size(0)
            num_val_correct += (torch.max(yhat, 1)[1] == y).sum().item()
            num_val_examples += y.shape[0]

        val_acc = num_val_correct / num_val_examples
        val_loss = val_loss / len(val_dl.dataset)

        print('Epoch %3d/%3d, train loss: %5.2f, train acc: %5.2f, val loss: %5.2f, val acc: %5.2f' %
              (epoch + 1, epochs, train_loss, train_acc, val_loss, val_acc))

        history['loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['acc'].append(train_acc)
        history['val_acc'].append(val_acc)

    # END OF TRAINING LOOP

    end_time_sec = time.time()
    total_time_sec = end_time_sec - start_time_sec
    time_per_epoch_sec = total_time_sec / epochs
    print()
    print('Time total:     %5.2f sec' % (total_time_sec))
    print('Time per epoch: %5.2f sec' % (time_per_epoch_sec))

    return history
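In case it matters, this is roughly how I call it (the model and data below are throwaway placeholders, not my real setup):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
train_dl = DataLoader(TensorDataset(torch.randn(64, 1, 28, 28),
                                    torch.randint(0, 10, (64,))), batch_size=16)
val_dl = DataLoader(TensorDataset(torch.randn(32, 1, 28, 28),
                                  torch.randint(0, 10, (32,))), batch_size=16)

history = train(model,
                optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
                loss_fn=nn.CrossEntropyLoss(),
                train_dl=train_dl,
                val_dl=val_dl,
                epochs=5,
                device='cpu')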
Indeed, the PyTorch Module class (source code) doesn't have a fit() method, so you have to implement your own according to your needs.
However, there are some implementations that mimic the Keras training API, such as this one:
https://github.com/ncullen93/torchsample
or a simpler one:
https://github.com/henryre/pytorch-fitmodule
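Under the hood these are essentially a class wrapping a loop like yours. A minimal sketch of the idea (purely illustrative; this is not the actual API of torchsample or pytorch-fitmodule, and it assumes tuple-style batches):

import torch

class FitWrapper:
    """Keras-flavoured wrapper around a plain PyTorch training loop (sketch only)."""

    def __init__(self, model, optimizer, loss_fn, device='cpu'):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.device = device

    def fit(self, train_dl, val_dl=None, epochs=1):
        history = {'loss': [], 'val_loss': []}
        for _ in range(epochs):
            # One pass over the training data with gradient updates.
            self.model.train()
            running = 0.0
            for x, y in train_dl:
                x, y = x.to(self.device), y.to(self.device)
                self.optimizer.zero_grad()
                loss = self.loss_fn(self.model(x), y)
                loss.backward()
                self.optimizer.step()
                running += loss.item() * x.size(0)
            history['loss'].append(running / len(train_dl.dataset))

            # Optional evaluation pass without gradient tracking.
            if val_dl is not None:
                self.model.eval()
                running = 0.0
                with torch.no_grad():
                    for x, y in val_dl:
                        x, y = x.to(self.device), y.to(self.device)
                        running += self.loss_fn(self.model(x), y).item() * x.size(0)
                history['val_loss'].append(running / len(val_dl.dataset))
        return history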
Short answer: there is no equivalent training loop for PyTorch and TF.keras, and there probably never will be one.
First of all, the training loop is syntactic sugar that is supposed to make one's life easier. From my point of view, "making life easier" is the motto of the TF.keras framework, and that is the main reason it has one. A training loop cannot be formalized as a single well-defined practice: it can vary a lot depending on the task, dataset, procedure, metric, and so on, and it would take a lot of effort to match all the options across the two frameworks. Furthermore, defining a fixed training-loop interface in PyTorch might be too restrictive for many actual users of the framework.
Matching the outputs of a network would require matching the behavior of every operation across the two frameworks, which is practically impossible. The frameworks don't necessarily provide the same sets of operations, and operations can be grouped into higher-level abstractions differently. Also, some common functions like sigmoid or BatchNorm may look well defined mathematically on paper, but in reality have dozens of implementation-specific details. Moreover, when improvements to an operation are introduced, it is up to each community to integrate those updates into the main framework distribution or simply ignore them. Needless to say, the developers of the two frameworks make these decisions independently and likely have different motivations behind them.
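For instance, BatchNorm already differs at the level of constructor defaults (values quoted from the two libraries' documentation; double-check against the versions you use, and note the two "momentum" parameters are defined in opposite directions):

import torch.nn as nn

# PyTorch BatchNorm2d defaults: eps=1e-5, momentum=0.1, where momentum is the
# weight given to the *new* batch statistic when updating the running average.
bn_torch = nn.BatchNorm2d(64)

# tf.keras BatchNormalization defaults: epsilon=1e-3, momentum=0.99, where
# momentum is the weight given to the *existing* running statistic instead.
# from tensorflow.keras.layers import BatchNormalization
# bn_keras = BatchNormalization(momentum=0.99, epsilon=1e-3)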
To sum it all up, matching the high-level details of the two frameworks would require enormous effort and would probably be very disruptive for existing users.