 

How do loss functions know for which model to compute gradients in PyTorch?

I am unsure how PyTorch manages to link the loss function to the model I want it to be computed for. There is never an explicit reference between the loss and the model, like the one between the model's parameters and the optimizer.

Say, for example, I want to train two networks on the same dataset using a single pass through the data. How would PyTorch link each loss function to the appropriate model? Here's the code for reference:

import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,)),
                              ])
# Download and load the training data
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

model = nn.Sequential(nn.Linear(784, 128),
                      nn.ReLU(),
                      nn.Linear(128, 64),
                      nn.ReLU(),
                      nn.Linear(64, 10),
                      nn.LogSoftmax(dim=1))

model2 = nn.Sequential(nn.Linear(784, 128),
                      nn.ReLU(),
                      nn.Linear(128, 10),
                      nn.LogSoftmax(dim=1))

# Define the loss
criterion = nn.NLLLoss()
criterion2 = nn.NLLLoss()

optimizer = optim.SGD(model.parameters(), lr=0.003)
optimizer2 = optim.SGD(model2.parameters(), lr=0.003)

epochs = 5
for e in range(epochs):
    running_loss = 0
    running_loss_2 = 0
    for images, labels in trainloader:
        # Flatten MNIST images into a 784 long vector
        images = images.view(images.shape[0], -1) # batch_size x total_pixels

        # Training pass
        optimizer.zero_grad()
        optimizer2.zero_grad()

        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()


        output2 = model2(images)
        loss2 = criterion2(output2, labels)
        loss2.backward()
        optimizer2.step()

        running_loss += loss.item()
        running_loss_2 += loss2.item()

    print(f"Training loss 1: {running_loss/len(trainloader)}")
    print(f"Training loss 2: {running_loss_2/len(trainloader)}")
    print()

So, once again: how does PyTorch know to compute the gradients for the right model when loss.backward() and loss2.backward() are called?

asked Nov 14 '19 09:11 by Zafir Stojanovski


1 Answer

Whenever you perform forward operations using one of your model's parameters (or any torch.Tensor with requires_grad=True), PyTorch records those operations in a computational graph. Every further operation on a tensor derived from them (a descendant in the graph) extends that graph. In your case, model is an nn.Module whose trainable tensors are model.parameters(), so as the forward pass runs, PyTorch builds a graph that starts at model.parameters() and ends at loss. During the backward pass this graph is traversed in reverse to propagate the gradients back to the parameters. For loss in your code above, the graph looks roughly like this:

model.parameters() --> [intermediate variables in model] -->  output --> loss
                                  ^                                        ^
                                  |                                        |
                               images                                     labels
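To make the diagram concrete, here is a minimal sketch that builds a scaled-down stand-in for the question's model (layer sizes are arbitrary) and walks the recorded graph backwards from loss.grad_fn through next_functions. It relies on the graph-introspection attributes PyTorch exposes on autograd nodes (next_functions, and .variable on the AccumulateGrad nodes at the leaves); these are meant for inspection rather than training code, and the helper reachable_leaves is hypothetical, written only for this illustration. The walk should end at exactly the model's parameters, while images only appears as a None branch because it does not require gradients.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Scaled-down stand-in for the question's model (layer sizes are arbitrary)
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2), nn.LogSoftmax(dim=1))
criterion = nn.NLLLoss()
images = torch.randn(5, 4)           # input data: requires_grad is False
labels = torch.randint(0, 2, (5,))

loss = criterion(model(images), labels)  # the forward pass records the graph
print(images.grad_fn, loss.grad_fn is not None)  # None True


def reachable_leaves(root_fn):
    """Hypothetical helper: walk the autograd graph backwards from root_fn
    and collect the leaf tensors it reaches."""
    leaves, seen, stack = [], set(), [root_fn]
    while stack:
        fn = stack.pop()
        # None marks an input that does not require grad (e.g. images)
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        # AccumulateGrad nodes expose the leaf tensor whose .grad they fill
        if hasattr(fn, "variable"):
            leaves.append(fn.variable)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return leaves


leaves = reachable_leaves(loss.grad_fn)
# Compare by storage address so the check does not rely on Python object identity
leaf_ptrs = {t.data_ptr() for t in leaves}
param_ptrs = {p.data_ptr() for p in model.parameters()}
print(len(leaves), leaf_ptrs == param_ptrs)  # 4 True
```

The four leaves are the weight and bias of each nn.Linear layer; nothing outside model can be reached from this loss.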

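When you call loss.backward(), PyTorch traverses this graph in reverse, reaches every leaf tensor in it that requires gradients (here, only model.parameters()), and accumulates the gradient into param.grad for each of them. The optimizer then uses these .grad values to update the parameters it was given. For loss2 the story is the same, except that its graph leads back only to model2.parameters(). So the link between a loss and a model is implicit: it is carried by the graph attached to the loss tensor (through loss.grad_fn), not by the loss function object. Because .grad is accumulated rather than overwritten, resetting it with optimizer.zero_grad() before each backward pass, as your code does, keeps one batch's gradients from leaking into the next.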
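A minimal sketch of the two-model case from the question, with shrunken, hypothetical layer sizes. For brevity it assumes a single shared NLLLoss instance; NLLLoss without class weights holds no parameters, so sharing it does not tie the models together. The sketch checks that loss.backward() fills .grad only for model, that loss2.backward() leaves model's gradients unchanged, and that a plain SGD step (no momentum, no weight decay) moves each parameter by -lr * grad.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Two independent models fed the same batch, as in the question (sizes shrunk)
model = nn.Sequential(nn.Linear(4, 3), nn.LogSoftmax(dim=1))
model2 = nn.Sequential(nn.Linear(4, 3), nn.LogSoftmax(dim=1))
criterion = nn.NLLLoss()  # no parameters, so one instance can serve both models
lr = 0.1
optimizer = optim.SGD(model.parameters(), lr=lr)

images = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))

loss = criterion(model(images), labels)
loss2 = criterion(model2(images), labels)

loss.backward()
model_has_grads = all(p.grad is not None for p in model.parameters())
model2_untouched = all(p.grad is None for p in model2.parameters())
print(model_has_grads, model2_untouched)  # True True

grads = [p.grad.clone() for p in model.parameters()]
loss2.backward()
model_grads_unchanged = all(torch.equal(g, p.grad) for g, p in zip(grads, model.parameters()))
model2_has_grads = all(p.grad is not None for p in model2.parameters())
print(model_grads_unchanged, model2_has_grads)  # True True

# Plain SGD (no momentum, no weight decay) moves each parameter by -lr * grad
before = [p.detach().clone() for p in model.parameters()]
optimizer.step()
sgd_matches = all(
    torch.allclose(p.detach(), b - lr * g)
    for p, b, g in zip(model.parameters(), before, grads)
)
print(sgd_matches)  # True
```

Which model receives gradients is decided entirely by which loss tensor backward() is called on; the criterion objects play no part in it.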

The official PyTorch tutorials, especially those covering torch.autograd, are a good resource for more in-depth information on this.

answered Sep 30 '22 17:09 by hdkrgr