I am unsure how PyTorch manages to link the loss function to the model I want it computed for. There is never an explicit reference between the loss and the model, such as the one between the model's parameters and the optimizer.
Say, for example, I want to train two networks on the same dataset, using a single pass through the data. How would PyTorch link the appropriate loss function to the appropriate model? Here's code for reference:
import torch
from torch import nn, optim
from torchvision import datasets, transforms

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
model = nn.Sequential(nn.Linear(784, 128),
                      nn.ReLU(),
                      nn.Linear(128, 64),
                      nn.ReLU(),
                      nn.Linear(64, 10),
                      nn.LogSoftmax(dim=1))

model2 = nn.Sequential(nn.Linear(784, 128),
                       nn.ReLU(),
                       nn.Linear(128, 10),
                       nn.LogSoftmax(dim=1))
# Define the loss
criterion = nn.NLLLoss()
criterion2 = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.003)
optimizer2 = optim.SGD(model2.parameters(), lr=0.003)
epochs = 5
for e in range(epochs):
    running_loss = 0
    running_loss_2 = 0
    for images, labels in trainloader:
        # Flatten MNIST images into a 784-long vector
        images = images.view(images.shape[0], -1)  # batch_size x total_pixels

        # Training pass
        optimizer.zero_grad()
        optimizer2.zero_grad()

        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        output2 = model2(images)
        loss2 = criterion2(output2, labels)
        loss2.backward()
        optimizer2.step()

        running_loss += loss.item()
        running_loss_2 += loss2.item()

    print(f"Training loss 1: {running_loss/len(trainloader)}")
    print(f"Training loss 2: {running_loss_2/len(trainloader)}")
    print()
So, once again: how does PyTorch know how to compute the appropriate gradients for the appropriate models when loss.backward() and loss2.backward() are called?
Whenever you perform forward operations using one of your model parameters (or any torch.Tensor that has requires_grad=True), PyTorch builds a computational graph. When you operate on descendants in this graph, the graph is extended. In your case, you have an nn.Module called model which has some trainable model.parameters(), so PyTorch builds a graph from your model.parameters() all the way to the loss as you perform the forward operations. The graph is then traversed in reverse during the backward pass to propagate the gradients back to the parameters.

For loss in your code above, the graph looks something like this:
model.parameters() --> [intermediate variables in model] --> output --> loss
                                       ^                                  ^
                                       |                                  |
                                     images                            labels
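If you want to poke at this graph yourself, you can inspect the grad_fn attribute on the loss. Here is a minimal, self-contained sketch, using a much smaller stand-in model with arbitrary shapes (the exact node names vary across PyTorch versions):

import torch
from torch import nn

toy_model = nn.Sequential(nn.Linear(784, 10), nn.LogSoftmax(dim=1))
toy_criterion = nn.NLLLoss()

images = torch.randn(64, 784)            # fake batch; requires_grad is False
labels = torch.randint(0, 10, (64,))

output = toy_model(images)               # forward ops build the graph
loss = toy_criterion(output, labels)

print(loss.grad_fn)                      # e.g. <NllLossBackward0 object at ...>
print(loss.grad_fn.next_functions)       # that node's parents, leading back
                                         # toward toy_model.parameters()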
When you call loss.backward(), PyTorch traverses this graph in reverse to reach all trainable leaves (only the model.parameters() in this case) and accumulates the gradient of the loss into each parameter's .grad attribute. The optimizer then relies on this information gathered during the backward pass to update the parameters.
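To see concretely that only the parameters on the traversed graph are touched, here is a short sketch (toy models again, with made-up shapes) checking .grad after a backward call:

import torch
from torch import nn

net_a = nn.Sequential(nn.Linear(784, 10), nn.LogSoftmax(dim=1))
net_b = nn.Sequential(nn.Linear(784, 10), nn.LogSoftmax(dim=1))
criterion = nn.NLLLoss()

images = torch.randn(64, 784)
labels = torch.randint(0, 10, (64,))

loss_a = criterion(net_a(images), labels)  # graph contains only net_a's parameters
loss_a.backward()

print(all(p.grad is not None for p in net_a.parameters()))  # True: gradients written
print(all(p.grad is None for p in net_b.parameters()))      # True: net_b untouched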
For loss2 the story is similar: it sits at the end of a second, entirely separate graph built from model2.parameters(), so loss2.backward() can only ever reach model2's parameters.
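As an aside (an alternative pattern, not something your code needs): because the two graphs share only the non-trainable inputs, you could also sum the losses and make a single backward call, and the gradients would still land on each model's own parameters. Using the variable names from your code:

# inside the inner training loop, instead of two separate backward calls
optimizer.zero_grad()
optimizer2.zero_grad()
loss = criterion(model(images), labels)
loss2 = criterion2(model2(images), labels)
(loss + loss2).backward()   # one traversal covers both disjoint sub-graphs
optimizer.step()
optimizer2.step()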
The official PyTorch tutorials are a good resource for more in-depth information on this.