 

Pytorch - inference all images and back-propagate batch by batch

I have a special use case in which I have to separate inference from back-propagation: I have to run inference on all images, slice the outputs into batches, and then back-propagate batch by batch. I don't need to update my network's weights.

I modified snippets of the cifar10_tutorial as follows to simulate my problem. j is a variable representing the index returned by my own logic, and I want the gradient with respect to some variables (here, the inputs).

for epoch in range(2):  # loop over the dataset multiple times

    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        inputs.requires_grad = True

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)

        for j in range(4): # j is given by external logic in my own case

            loss = criterion(outputs[j, :].unsqueeze(0), labels[j].unsqueeze(0))

            loss.backward()

            print(inputs.grad.data[j, :]) # what I really want

I got the following error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
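The error can be reproduced with a small self-contained sketch. It assumes a tiny Linear+Tanh model as a hypothetical stand-in for net and random tensors in place of the CIFAR-10 batch; everything else mirrors the loop above. The first backward() call succeeds and frees the saved tensors of net's graph, so the second call (j = 1) raises the RuntimeError.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the question's net, criterion and data
net = nn.Sequential(nn.Linear(4, 3), nn.Tanh())
criterion = nn.CrossEntropyLoss()
inputs = torch.rand(4, 4, requires_grad=True)
labels = torch.tensor([0, 1, 2, 0])

outputs = net(inputs)  # a single forward pass for the whole batch

failed_at, message = None, ""
for j in range(4):
    loss = criterion(outputs[j, :].unsqueeze(0), labels[j].unsqueeze(0))
    try:
        # the first call frees the saved tensors of net's graph
        loss.backward()
    except RuntimeError as err:
        failed_at, message = j, str(err)
        break

print(failed_at)  # → 1
print(message)
```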

My questions are:

  1. As I understand it, the problem arises because the first backward pass goes through the whole outputs tensor, so the buffers needed by outputs[1,:].unsqueeze(0) are freed and the second backward pass fails. Am I right?

  2. In my case, if I set retain_graph=True, will the code run slower and slower, as described in this post?

  3. Is there a better way to achieve my goal?

asked Dec 19 '18 by Tengerye




1 Answer

  1. Yes, you are correct. When you back-propagate through outputs the first time (first iteration of your loop), the buffers of the graph are freed, so the next backward call (next iteration) fails because the data needed for that computation has already been removed.

  2. Yes, it can be slower: with retain_graph=True the buffers are kept alive, and every backward() call in your j-loop traverses the whole network again instead of only once. How much slower depends on your network architecture and on GPU (or CPU) usage; I used this once and it was much slower. In any case, you will need more memory with retain_graph=True than without.

  3. Depending on the shapes of your outputs and labels, you should be able to compute the loss for all of them at once:

    criterion(outputs, labels)
    

    You can then drop the j-loop, which also makes your code faster. You may need to reshape (or view) your data, but this should work fine.

    If for some reason you cannot do that, you can manually sum the losses into a tensor and call backward after the loop. This works too, but is slower than the solution above.

    Your code would then look like this:

    # init the loss tensor on the same device as your outputs (CPU or GPU)
    loss = torch.tensor(0.0, device=outputs.device)
    
    for j in range(4):
        # summing up your loss for every j
        loss += criterion(outputs[j, :].unsqueeze(0), labels[j].unsqueeze(0))
        # ...
    # calling backward once on the summed loss to get the gradients
    loss.backward()
    # since backward is now called only once on the outputs,
    # you shouldn't get the error and don't need retain_graph=True
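A self-contained sketch comparing the two working variants on the question's setup: one backward() call per j, keeping the graph alive with retain_graph=True on every call except the last, versus the summed loss with a single backward() call. The tiny Linear+Tanh model, the random inputs and the labels are hypothetical stand-ins for net and the CIFAR-10 data; the check is that both variants yield the same input gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the question's net, criterion and data
net = nn.Sequential(nn.Linear(4, 3), nn.Tanh())
criterion = nn.CrossEntropyLoss()
inputs = torch.rand(4, 4, requires_grad=True)
labels = torch.tensor([0, 1, 2, 0])


def per_sample_backward():
    """One backward call per j; buffers are freed only on the final call."""
    inputs.grad = None
    outputs = net(inputs)
    for j in range(4):
        loss = criterion(outputs[j, :].unsqueeze(0), labels[j].unsqueeze(0))
        loss.backward(retain_graph=j < 3)
    return inputs.grad.clone()


def summed_backward():
    """Accumulate the per-j losses and call backward once."""
    inputs.grad = None
    outputs = net(inputs)
    loss = torch.tensor(0.0, device=outputs.device)
    for j in range(4):
        loss += criterion(outputs[j, :].unsqueeze(0), labels[j].unsqueeze(0))
    loss.backward()
    return inputs.grad.clone()


grad_a = per_sample_backward()
grad_b = summed_backward()
print(torch.allclose(grad_a, grad_b))  # → True
```

The retain_graph variant is what the question's loop needs to run at all; the summed variant avoids keeping the buffers alive across several backward passes.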
    

Edit:

Accumulating the losses and calling backward later is completely equivalent. Here is a small example with and without accumulating the losses:

First, create some data:

import torch

# w in this case will represent a very simple model
# I leave out the CE and just use w to map each input to a scalar value
w = torch.nn.Linear(4, 1)
data = [torch.rand(1, 4) for j in range(4)]

data looks like this:

[tensor([[0.4593, 0.3410, 0.1009, 0.9787]]),
 tensor([[0.1128, 0.0678, 0.9341, 0.3584]]),
 tensor([[0.7076, 0.9282, 0.0573, 0.6657]]),
 tensor([[0.0960, 0.1055, 0.6877, 0.0406]])]

Let's first do it the way you are doing it, calling backward for every iteration j separately:

# code for directly applying backward
# zero the gradients of layer w
w.zero_grad()
for j, inp in enumerate(data):
    # activate grad flag
    inp.requires_grad = True
    # remove / zero previous gradients for inputs
    inp.grad = None
    # apply model (only consists of one layer in our case)
    loss = w(inp)
    # calling backward on every output separately
    loss.backward()
    # print out grad
    print('Input:', inp)
    print('Grad:', inp.grad)
    print()
print('w.weight.grad:', w.weight.grad)

Here is the printout of every input with its gradient, followed by the gradient of the model (i.e. layer w in our simplified case):

Input: tensor([[0.4593, 0.3410, 0.1009, 0.9787]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

Input: tensor([[0.1128, 0.0678, 0.9341, 0.3584]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

Input: tensor([[0.7076, 0.9282, 0.0573, 0.6657]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

Input: tensor([[0.0960, 0.1055, 0.6877, 0.0406]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

w.weight.grad: tensor([[1.3757, 1.4424, 1.7801, 2.0434]])

Now, instead of calling backward once for every iteration j, we accumulate the values, call backward on the sum, and compare the results:

# init tensor for accumulation
loss = torch.tensor(0.0)
# zero layer gradients
w.zero_grad()
for j, inp in enumerate(data):
    # activate grad flag
    inp.requires_grad = True
    # remove / zero previous gradients for inputs
    inp.grad = None
    # apply model (only consists of one layer in our case)
    # accumulating values instead of calling backward
    loss += w(inp).squeeze()
# calling backward on the sum
loss.backward()

# printing out gradients 
for j, inp in enumerate(data):
    print('Input:', inp)
    print('Grad:', inp.grad)
    print()
print('w.grad:', w.weight.grad)

Let's take a look at the results:

Input: tensor([[0.4593, 0.3410, 0.1009, 0.9787]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

Input: tensor([[0.1128, 0.0678, 0.9341, 0.3584]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

Input: tensor([[0.7076, 0.9282, 0.0573, 0.6657]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

Input: tensor([[0.0960, 0.1055, 0.6877, 0.0406]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

w.grad: tensor([[1.3757, 1.4424, 1.7801, 2.0434]])

Comparing the results, we can see that both are the same.
This is a very simple example, but it shows that calling backward() on every single tensor and summing the tensors before calling backward() once give equivalent gradients for both the inputs and the weights.
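The equivalence follows from the linearity of differentiation, since backward() accumulates into .grad. The short derivation below also explains the printouts: for a linear layer every input receives the same gradient (the weights of w), and w.weight.grad is the sum of the four inputs, e.g. 0.4593 + 0.1128 + 0.7076 + 0.0960 = 1.3757 in the first column. Here x_j and W are written as 1×4 row vectors, matching the shapes of inp and w.weight.

```latex
% J separate backward() calls accumulate \sum_j \nabla L_j into .grad;
% one call on the summed loss computes \nabla \sum_j L_j. By linearity:
\nabla_\theta \Big( \sum_{j=1}^{J} L_j \Big) = \sum_{j=1}^{J} \nabla_\theta L_j

% Toy layer above: x_j \in \mathbb{R}^{1 \times 4}, \; W \in \mathbb{R}^{1 \times 4}
L_j = x_j W^\top + b, \qquad
\frac{\partial L_j}{\partial x_j} = W, \qquad
\frac{\partial L_j}{\partial x_k} = 0 \;\; (k \neq j), \qquad
\frac{\partial}{\partial W} \sum_{j} L_j = \sum_{j} x_j
```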

When you compute the CE for all j's at once as described in 3., you can pass reduction='sum' to get the same behaviour as summing up the individual CE values above. The default is reduction='mean', which (without class weights) divides the summed loss by the batch size N, so the resulting gradients are scaled by 1/N.
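A small sketch of the reduction point, assuming random logits in place of the network's outputs and no class weights: a single CrossEntropyLoss(reduction="sum") call reproduces the gradients of the per-sample losses summed in a loop, while the default "mean" gives the same gradients divided by the batch size N.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical logits for a batch of N = 4 samples and 3 classes
logits = torch.randn(4, 3, requires_grad=True)
labels = torch.tensor([0, 1, 2, 0])
N = logits.shape[0]

ce_sum = nn.CrossEntropyLoss(reduction="sum")
ce_mean = nn.CrossEntropyLoss()  # default reduction="mean"


def grad_wrt_logits(loss):
    """Return d(loss)/d(logits) for a freshly built graph."""
    logits.grad = None
    loss.backward()
    return logits.grad.clone()


# Loop over j as in the answer, summing the per-sample losses
loop_grad = grad_wrt_logits(
    sum(ce_mean(logits[j].unsqueeze(0), labels[j].unsqueeze(0)) for j in range(N))
)
# One call over the whole batch
sum_grad = grad_wrt_logits(ce_sum(logits, labels))
mean_grad = grad_wrt_logits(ce_mean(logits, labels))

print(torch.allclose(loop_grad, sum_grad))      # → True
print(torch.allclose(mean_grad * N, sum_grad))  # → True
```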

answered Oct 11 '22 by MBT