I am trying to comprehend the inner workings of gradient accumulation in PyTorch. My question is somewhat related to these two:
Why do we need to call zero_grad() in PyTorch?
Why do we need to explicitly call zero_grad()?
Comments to the accepted answer to the second question suggest that accumulated gradients can be used if a minibatch is too large to perform a gradient update in a single forward pass, and thus has to be split into multiple sub-batches.
Consider the following toy example:
import numpy as np
import torch


class ExampleLinear(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # Initialize the weight at 1
        self.weight = torch.nn.Parameter(torch.Tensor([1]).float(),
                                         requires_grad=True)

    def forward(self, x):
        return self.weight * x


if __name__ == "__main__":
    # Example 1
    model = ExampleLinear()

    # Generate some data
    x = torch.from_numpy(np.array([4, 2])).float()
    y = 2 * x

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    y_hat = model(x)          # forward pass
    loss = (y - y_hat) ** 2
    loss = loss.mean()        # MSE loss
    loss.backward()           # backward pass
    optimizer.step()          # weight update

    print(model.weight.grad)  # tensor([-20.])
    print(model.weight)       # tensor([1.2000])
Which is exactly the result one would expect. Now assume that we want to process the dataset sample-by-sample utilizing gradient accumulation:
# Example 2: MSE sample-by-sample
model2 = ExampleLinear()
optimizer = torch.optim.SGD(model2.parameters(), lr=0.01)

# Compute loss sample-by-sample, then average it over all samples
loss = []
for k in range(len(y)):
    y_hat = model2(x[k])
    loss.append((y[k] - y_hat) ** 2)
loss = sum(loss) / len(y)

loss.backward()            # backward pass
optimizer.step()           # weight update

print(model2.weight.grad)  # tensor([-20.])
print(model2.weight)       # tensor([1.2000])
Again as expected, the gradient is calculated when the .backward() method is called.
Finally to my question: what exactly happens 'under the hood'?
My understanding is that the computational graph is dynamically updated going from <PowBackward> to <AddBackward> <DivBackward> operations for the loss variable, and that no information about the data used for each forward pass is retained anywhere except for the loss tensor which can be updated until the backward pass.
Are there any caveats to the reasoning in the above paragraph? Lastly, are there any best practices to follow when using gradient accumulation (i.e. can the approach I use in Example 2 backfire somehow)?
Gradient accumulation is a technique where you can train on bigger batch sizes than your machine would normally be able to fit into memory. This is done by accumulating gradients over several batches, and only stepping the optimizer after a certain number of batches have been performed.
Gradient accumulation is a particularly good option where there's only access to a single GPU, because it can be run sequentially on the single resource. Although the concept is simple, the mathematics and code required to implement gradient accumulation can be complicated.
To compute the gradients, a tensor must have requires_grad=True. The gradients are the same as the partial derivatives. For example, in the function y = 2*x + 1, x is a tensor with requires_grad=True. We can compute the gradients using y.
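As a minimal sketch of the y = 2*x + 1 example (the input value 3.0 is an arbitrary choice for illustration), calling backward() on y stores dy/dx in x.grad:

```python
import torch

# x is a leaf tensor that tracks gradients; 3.0 is an arbitrary example value
x = torch.tensor(3.0, requires_grad=True)
y = 2 * x + 1

# Backpropagate from y; the partial derivative dy/dx = 2 ends up in x.grad
y.backward()

print(x.grad)  # tensor(2.)
```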
The training speed can be accelerated when combining DDP and gradient accumulation. When applying gradient accumulation, optimizer.step() is called every K steps instead of every step. And as we know every training step (with loss.
You are not actually accumulating gradients. Just leaving off optimizer.zero_grad() has no effect if you have a single .backward() call, as the gradients are already zero to begin with (technically None, but they will be automatically initialised to zero).
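To make that concrete, here is a small sketch (the function 3 * w is just an arbitrary choice with a known derivative of 3) showing that .grad starts out as None and only grows when .backward() is called more than once without zeroing:

```python
import torch

# A single trainable parameter, as in the toy model
w = torch.nn.Parameter(torch.tensor([1.0]))
grad_before = w.grad  # no backward pass has run yet

# First backward pass: d(3 * w)/dw = 3
(3 * w).sum().backward()
grad_after_first = w.grad.clone()

# Second backward pass without zero_grad(): the new gradient is added to the old one
(3 * w).sum().backward()
grad_after_second = w.grad.clone()

print(grad_before)        # None
print(grad_after_first)   # tensor([3.])
print(grad_after_second)  # tensor([6.])
```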
The only difference between your two versions is how you calculate the final loss. The for loop of the second example does the same calculations as PyTorch does in the first example, but you do them individually, and PyTorch cannot optimise (parallelise and vectorise) your for loop. That makes an especially staggering difference on GPUs, provided the tensors aren't tiny.
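A quick way to check that both versions compute the same thing is to compare the gradients directly. This sketch uses torch.autograd.grad instead of .backward() so that nothing is written to .grad between the two computations; the variable names are my own:

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
x = torch.tensor([4.0, 2.0])
y = 2 * x

# Vectorised loss, as in Example 1
loss_vec = ((y - w * x) ** 2).mean()

# Sample-by-sample loss with a manual mean, as in Example 2
loss_loop = sum((y[k] - w * x[k]) ** 2 for k in range(len(y))) / len(y)

# torch.autograd.grad returns the gradients instead of accumulating them in w.grad
(grad_vec,) = torch.autograd.grad(loss_vec, w)
(grad_loop,) = torch.autograd.grad(loss_loop, w)

print(grad_vec, grad_loop)  # both tensor([-20.])
```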
Before getting to gradient accumulation, let's start with your question:
Finally to my question: what exactly happens 'under the hood'?
Every operation on tensors is tracked in a computational graph if and only if one of the operands is already part of a computational graph. When you set requires_grad=True on a tensor, it creates a computational graph with a single vertex, the tensor itself, which will remain a leaf in the graph. Any operation with that tensor will create a new vertex, which is the result of the operation, hence there is an edge from the operands to it, tracking the operation that was performed.
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(4.0)
c = a + b # => tensor(6., grad_fn=<AddBackward0>)
a.requires_grad # => True
a.is_leaf # => True
b.requires_grad # => False
b.is_leaf # => True
c.requires_grad # => True
c.is_leaf # => False
Every intermediate tensor automatically requires gradients and has a grad_fn, which is the function to calculate the partial derivatives with respect to its inputs. Thanks to the chain rule, we can traverse the whole graph in reverse order to calculate the derivatives with respect to every single leaf, which are the parameters we want to optimise. That's the idea of backpropagation, also known as reverse mode differentiation. For more details I recommend reading Calculus on Computational Graphs: Backpropagation.
PyTorch uses that exact idea: when you call loss.backward(), it traverses the graph in reverse order, starting from loss, and calculates the derivatives for each vertex. Whenever a leaf is reached, the calculated derivative for that tensor is added to its .grad attribute.
In your first example, that would lead to:
MeanBackward -> PowBackward -> SubBackward -> MulBackward
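One way to see this chain for yourself is to follow grad_fn.next_functions from the loss back to the leaf. This is only a sketch: backward_chain is a hypothetical helper of mine, not a PyTorch function, and the exact node names (including their numeric suffixes) may differ between PyTorch versions.

```python
import torch

weight = torch.nn.Parameter(torch.tensor([1.0]))
x = torch.tensor([4.0, 2.0])
y = 2 * x

loss = ((y - weight * x) ** 2).mean()


def backward_chain(node):
    """Collect node type names, following the first input that has a backward node."""
    names = []
    while node is not None:
        names.append(type(node).__name__)
        # Inputs that do not require gradients (such as y) show up as None
        parents = [fn for fn, _ in node.next_functions if fn is not None]
        node = parents[0] if parents else None
    return names


chain = backward_chain(loss.grad_fn)
print(" -> ".join(chain))
```

The last entry in the chain is the node that writes the result into the leaf's .grad, so the four operations above should appear in front of it.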
The second example is almost identical, except that you calculate the mean manually, and instead of having a single path for the loss, you have multiple paths for each element of the loss calculation. To clarify, the single path also calculates the derivatives of each element, but internally, which again opens up the possibilities for some optimisations.
# Example 1
loss = (y - y_hat) ** 2
# => tensor([16., 4.], grad_fn=<PowBackward0>)

# Example 2
loss = []
for k in range(len(y)):
    y_hat = model2(x[k])
    loss.append((y[k] - y_hat) ** 2)
loss
# => [tensor([16.], grad_fn=<PowBackward0>), tensor([4.], grad_fn=<PowBackward0>)]
In either case a single graph is created that is backpropagated exactly once, which is why it's not considered gradient accumulation.
Gradient accumulation refers to the situation where multiple backward passes are performed before updating the parameters. The goal is to have the same model parameters for multiple inputs (batches) and then update the model's parameters based on all these batches, instead of performing an update after every single batch.
Let's revisit your example. x has size [2], that's the size of our entire dataset. For some reason, we need to calculate the gradients based on the whole dataset. That is naturally the case when using a batch size of 2, since we would have the whole dataset at once. But what happens if we can only have batches of size 1? We could run them individually and update the model after each batch as usual, but then we don't calculate the gradients over the whole dataset.
What we need to do is run each sample individually with the same model parameters and calculate the gradients without updating the model. Now you might be thinking, isn't that what you did in the second version? Almost, but not quite, and there is a crucial problem in your version, namely that you are using the same amount of memory as in the first version, because you have the same calculations and therefore the same number of values in the computational graph.
How do we free memory? We need to get rid of the tensors of the previous batch and also the computational graph, because that uses a lot of memory to keep track of everything that's necessary for the backpropagation. The computational graph is automatically destroyed when .backward() is called (unless retain_graph=True is specified).
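You can observe that the graph is freed by trying to backpropagate through the same loss twice. This is a small sketch with my own variable names; the second call is expected to fail because the values saved for the backward pass were released by the first one:

```python
import torch

weight = torch.nn.Parameter(torch.tensor([1.0]))
x = torch.tensor([4.0, 2.0])
loss = ((2 * x - weight * x) ** 2).mean()

loss.backward()  # frees the buffers saved for the backward pass

try:
    loss.backward()  # the saved tensors are gone, so this raises
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

print(second_backward_failed)  # True
```

That is exactly what makes the per-batch approach below cheaper in memory: each batch's graph is gone before the next batch is processed.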
def calculate_loss(x: torch.Tensor) -> torch.Tensor:
    y = 2 * x
    y_hat = model(x)
    loss = (y - y_hat) ** 2
    return loss.mean()


# With multiple batches of size 1
batches = [torch.tensor([4.0]), torch.tensor([2.0])]

optimizer.zero_grad()
for i, batch in enumerate(batches):
    # The loss needs to be scaled, because the mean should be taken across the whole
    # dataset, which requires the loss to be divided by the number of batches.
    loss = calculate_loss(batch) / len(batches)
    loss.backward()
    print(f"Batch size 1 (batch {i}) - grad: {model.weight.grad}")
    print(f"Batch size 1 (batch {i}) - weight: {model.weight}")

# Updating the model only after all batches
optimizer.step()
print(f"Batch size 1 (final) - grad: {model.weight.grad}")
print(f"Batch size 1 (final) - weight: {model.weight}")
Output (I removed the Parameter containing messages for readability):
Batch size 1 (batch 0) - grad: tensor([-16.])
Batch size 1 (batch 0) - weight: tensor([1.], requires_grad=True)
Batch size 1 (batch 1) - grad: tensor([-20.])
Batch size 1 (batch 1) - weight: tensor([1.], requires_grad=True)
Batch size 1 (final) - grad: tensor([-20.])
Batch size 1 (final) - weight: tensor([1.2000], requires_grad=True)
As you can see, the model kept the same parameter for all batches while the gradients were accumulated, and there is a single update at the end. Note that the loss needs to be scaled per batch, in order to have the same significance over the whole dataset as if you used a single batch.
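The numbers in the output can also be reproduced by hand. For one sample the scaled loss is (y - w * x) ** 2 / n, so its derivative with respect to w is -2 * x * (y - w * x) / n. The following plain-Python sketch (scaled_grad is a hypothetical helper, not part of PyTorch) accumulates that for both samples and applies the SGD update:

```python
# Hand-derived gradient of the scaled per-sample loss (y - w * x) ** 2 / n
def scaled_grad(x, y, w, n):
    return -2.0 * x * (y - w * x) / n


w, lr = 1.0, 0.01
samples = [(4.0, 8.0), (2.0, 4.0)]  # (x, y) pairs with y = 2 * x

accumulated = 0.0
running = []
for x, y in samples:
    # Same role as loss.backward(): add this batch's gradient to the total
    accumulated += scaled_grad(x, y, w, len(samples))
    running.append(accumulated)

# Same role as optimizer.step() for plain SGD
new_w = w - lr * accumulated

print(running)          # [-16.0, -20.0]
print(round(new_w, 4))  # 1.2
```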
While in this example, the whole dataset is used before performing the update, you can easily change that to update the parameters after a certain number of batches, but you have to remember to zero out the gradients after an optimiser step was taken. The general recipe would be:
accumulation_steps = 10
for i, batch in enumerate(batches):
    # Scale the loss to the mean of the accumulated batch size
    loss = calculate_loss(batch) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        # Reset gradients, for the next accumulated batches
        optimizer.zero_grad()
You can find that recipe and more techniques for working with large batch sizes in HuggingFace - Training Neural Nets on Larger Batches: Practical Tips for 1-GPU, Multi-GPU & Distributed setups.
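As a sanity check of the recipe, the following sketch trains the same toy weight twice: once with four batches of size 1 and accumulation_steps=2, and once with two batches of size 2 and an update after every batch. The train function, the four data points and the learning rate are my own assumptions for illustration; both runs should end up with the same weight.

```python
import torch


def train(batches, accumulation_steps, lr=0.01):
    """Run the accumulation recipe on a single weight initialised at 1."""
    weight = torch.nn.Parameter(torch.tensor([1.0]))
    optimizer = torch.optim.SGD([weight], lr=lr)
    optimizer.zero_grad()
    for i, x in enumerate(batches):
        y = 2 * x
        # Scale the loss to the mean of the accumulated batch size
        loss = ((y - weight * x) ** 2).mean() / accumulation_steps
        loss.backward()
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            # Reset gradients for the next accumulated batches
            optimizer.zero_grad()
    return weight.detach().clone()


data = torch.tensor([4.0, 2.0, 1.0, 3.0])

# Four batches of size 1, one update every 2 batches
accumulated = train(list(data.split(1)), accumulation_steps=2)
# Two batches of size 2, one update after every batch
full = train(list(data.split(2)), accumulation_steps=1)

same = torch.allclose(accumulated, full)
print(accumulated, full, same)
```

One thing to watch out for: if the number of batches is not a multiple of accumulation_steps, the gradients of the last few batches are computed but never applied by this loop, so you may want to handle that remainder explicitly.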