The documentation for tf.train.MomentumOptimizer offers a use_nesterov parameter to utilise Nesterov's Accelerated Gradient (NAG) method.
However, NAG requires the gradient at a location other than that of the current variable to be calculated, and the apply_gradients interface only allows for the current gradient to be passed. So I don't quite understand how the NAG algorithm could be implemented with this interface.
The documentation says the following about the implementation:
use_nesterov: If True use Nesterov Momentum. See Sutskever et al., 2013. This implementation always computes gradients at the value of the variable(s) passed to the optimizer. Using Nesterov Momentum makes the variable(s) track the values called theta_t + mu*v_t in the paper.
Having read through the paper in the link, I'm a little unsure about whether this description answers my question or not. How can the NAG algorithm be implemented when the interface doesn't require a gradient function to be provided?
Gradient descent is an optimization algorithm that uses the gradient of the objective function to navigate the search space. The convergence of the gradient descent algorithm can be accelerated by extending it with Nesterov Momentum.
Nadam Optimization Algorithm. The Nesterov-accelerated Adaptive Moment Estimation (Nadam) algorithm is an extension of the Adaptive Moment Estimation (Adam) optimization algorithm that adds Nesterov's Accelerated Gradient (NAG), or Nesterov momentum, which is an improved type of momentum.
Nesterov Accelerated Gradient is a momentum-based SGD optimizer that "looks ahead" to where the parameters will be to calculate the gradient ex post rather than ex ante:
v_t = γ v_{t-1} + η ∇_θ J(θ_{t-1} - γ v_{t-1})
θ_t = θ_{t-1} - v_t
NAG makes smaller oscillations and takes shorter U-turns, even when approaching the minima valleys on the error surface. Looking ahead helps NAG correct its course more quickly than momentum-based gradient descent. Hence the oscillations are smaller, and so are the chances of escaping the minima valley.
TL;DR
TF's implementation of Nesterov is indeed an approximation of the original formula, valid for high values of momentum.
Details
This is a great question. In the paper, the NAG update is defined as
v_{t+1} = μ v_t - λ ∇f(θ_t + μ v_t)
θ_{t+1} = θ_t + v_{t+1}
where f is our cost function, θ_t our parameters at time t, μ the momentum, λ the learning rate; v_t is NAG's internal accumulator.
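To make the paper's update concrete, here is a minimal sketch of those two lines in plain Python. The quadratic cost, the starting point and the hyperparameter values are my own assumptions for illustration; they are not from the paper or from TensorFlow. Note that `grad_f` has to be callable at the look-ahead point, which is exactly what the question is about.

```python
def grad_f(theta):
    # Gradient of an illustrative cost f(theta) = 0.5 * theta**2 (assumption).
    return theta


def nag_step(theta, v, lr, mu):
    # The gradient is evaluated at the look-ahead point theta + mu * v,
    # not at theta itself.
    v_next = mu * v - lr * grad_f(theta + mu * v)
    theta_next = theta + v_next
    return theta_next, v_next


theta, v = 1.0, 0.0
for _ in range(200):
    theta, v = nag_step(theta, v, lr=0.1, mu=0.9)

print(theta)  # very close to the minimum at 0 for this toy problem
```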
The main difference with standard momentum is the use of the gradient at θ_t + μ v_t, not at θ_t. But as you said, TensorFlow only uses the gradient at θ_t. So what is the trick?
Part of the trick is actually mentioned in the part of the documentation you cited: the algorithm is tracking θ_t + μ v_t, not θ_t. The other part comes from an approximation valid for high values of momentum.
Let's make a slight change of notation from the paper for the accumulator to stick with TensorFlow's definition. Let's define a_t = v_t / λ. The update rules are changed slightly as
a_{t+1} = μ a_t - ∇f(θ_t + μ λ a_t)
θ_{t+1} = θ_t + λ a_{t+1}
(The motivation for this change in TF is that now a is a pure gradient momentum, independent of the learning rate. This makes the update process robust to changes in λ, a possibility common in practice but that the paper does not consider.)
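As a quick sanity check of this change of variables, the sketch below runs the v-form and the a-form side by side with a constant learning rate and confirms they produce the same parameters. The cost function and the numeric values are assumptions picked only for the demonstration.

```python
def grad_f(theta):
    # Gradient of an illustrative cost f(theta) = 1.5 * theta**2 (assumption).
    return 3.0 * theta


lr, mu = 0.05, 0.9

theta_v, v = 1.0, 0.0  # paper's form, accumulator v
theta_a, a = 1.0, 0.0  # rescaled form, accumulator a = v / lr

for _ in range(50):
    # Paper's form.
    v = mu * v - lr * grad_f(theta_v + mu * v)
    theta_v = theta_v + v
    # Rescaled form: the learning rate only appears where a is applied.
    a = mu * a - grad_f(theta_a + mu * lr * a)
    theta_a = theta_a + lr * a

difference = abs(theta_v - theta_a)
print(difference)  # zero up to floating-point rounding
```

The two forms only coincide like this while the learning rate stays fixed; once λ changes between steps they behave differently, which is the point of the parenthetical remark.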
If we note ψ_t = θ_t + μ λ a_t, then
a_{t+1} = μ a_t - ∇f(ψ_t)
ψ_{t+1} = θ_{t+1} + μ λ a_{t+1}
= θ_t + λ a_{t+1} + μ λ a_{t+1}
= ψ_t + λ a_{t+1} + μ λ (a_{t+1} - a_t)
= ψ_t + λ a_{t+1} + μ λ [(μ - 1) a_t - ∇f(ψ_t)]
≈ ψ_t + λ a_{t+1}
(The step that introduces a_{t+1} - a_t uses θ_t = ψ_t - μ λ a_t.)
This last approximation holds for high values of momentum, where μ is close to 1 so that μ - 1 is close to zero, and where ∇f(ψ_t) is small compared to a_t. That second assumption is actually more debatable, and less valid for directions where the gradient switches sign frequently.
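The exact part of that derivation (everything before the ≈) can be checked numerically. The sketch below runs the a-form of NAG, tracks ψ alongside θ, and confirms that ψ_{t+1} = ψ_t + λ a_{t+1} + μ λ (a_{t+1} - a_t) holds at every step. The variable `correction` is the term that the final approximation drops. The cost function and hyperparameters are assumptions for illustration only.

```python
def grad_f(x):
    # Gradient of an illustrative cost f(x) = 0.5 * x**2 (assumption).
    return x


lr, mu = 0.1, 0.9
theta, a = 1.0, 0.0
max_residual = 0.0
corrections = []

for _ in range(20):
    psi = theta + mu * lr * a
    a_next = mu * a - grad_f(psi)
    theta_next = theta + lr * a_next
    psi_next = theta_next + mu * lr * a_next

    # Term dropped by the approximation in the last line of the derivation.
    correction = mu * lr * (a_next - a)
    corrections.append(correction)

    # Exact recursion for psi, before any approximation.
    exact = psi + lr * a_next + correction
    max_residual = max(max_residual, abs(psi_next - exact))

    theta, a = theta_next, a_next

print(max_residual)  # zero up to floating-point rounding
```

Printing `corrections` next to `lr * a_next` for your own problem is a simple way to see how large the dropped term is relative to the step that is kept.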
We now have an update that uses the gradient at the current position, and the rules are pretty simple: they are in fact those of standard momentum.
However, we want θ_t, not ψ_t. This is the reason why we subtract μ λ a_{t+1} from ψ_{t+1} just before returning it, and to recover ψ it is added back first thing at the next call.
I couldn't see any info on this online, and the linked paper certainly wasn't helpful, so I had a look at the unit tests for tf.train.MomentumOptimizer, from which I can see tests for the implementation of both classic momentum and NAG modes.
var = var + accum * learning_rate * momentum
accum = accum * momentum + g
var = var - learning_rate * accum
var = var - accum * learning_rate * momentum
where accum starts at 0 and is updated at every step. The above is a modified version of the formulation in the unit test, and I find it a bit confusing. Here is the same set of equations arranged with my interpretation of what each of the parameters represents (I could be wrong though):
average_grad_0 = accum # previous rolling average
average_grad_1 = accum * momentum + g # updated rolling average
grad_diff = average_grad_1 - average_grad_0
adjustment = -learning_rate * (grad_diff * momentum + average_grad_1)
var += adjustment
accum = average_grad_1
In other words, it seems to me like tensorflow's implementation attempts to guess the "adjusted gradient" in NAG by assuming that the new gradient will be estimated by the current average gradient plus the product of momentum and the change in the average gradient. I'd love to see a proof for this!
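To check that my rearrangement really is the same update as the four-line version, here is a small sketch that runs both side by side. The function names are mine, and the constant gradient and starting values are just borrowed from the test parameters for illustration.

```python
def four_line_update(var, accum, g, learning_rate, momentum):
    # The four-line formulation, as quoted from the unit test.
    var = var + accum * learning_rate * momentum
    accum = accum * momentum + g
    var = var - learning_rate * accum
    var = var - accum * learning_rate * momentum
    return var, accum


def rearranged_update(var, accum, g, learning_rate, momentum):
    # My rearrangement in terms of rolling averages of the gradient.
    average_grad_0 = accum
    average_grad_1 = accum * momentum + g
    grad_diff = average_grad_1 - average_grad_0
    var = var - learning_rate * (grad_diff * momentum + average_grad_1)
    return var, average_grad_1


var_a, acc_a = 1.0, 0.0
var_b, acc_b = 1.0, 0.0
for _ in range(5):
    var_a, acc_a = four_line_update(var_a, acc_a, 0.1, 2.0, 0.9)
    var_b, acc_b = rearranged_update(var_b, acc_b, 0.1, 2.0, 0.9)

print(abs(var_a - var_b), abs(acc_a - acc_b))  # both zero up to rounding
```

This only confirms the algebra of the rearrangement; it says nothing about whether the update matches the NAG formula from the paper.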
What follows is more detail on how the classic and nesterov modes are implemented in tensorflow as per the tests.
For use_nesterov=False, based on the doTestBasic function, we have the following initial parameters:
learning_rate = 2.0
momentum = 0.9
var_0 = 1.0 # at time 0
grad = 0.1
Actually, the above are just the first elements of the grads_0 and vars_0 arrays, but I'll just focus on a single value. For the subsequent timesteps, we have
var_1 = 1.0 - (0.1 * 2.0)
var_2 = 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
which I'm going to interpret as meaning:
var_1 = var_0 - (grad * learning_rate)
var_2 = var_1 - ((momentum * grad + grad) * learning_rate)
If we assume that for the purposes of the unit tests grad_0 == grad_1 == grad then this makes sense as a formulation of classic momentum.
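The arithmetic above can be reproduced with a two-step loop. This is a sketch of classic momentum under the same assumption of a constant gradient, using an accumulator that starts at 0. The variable names are mine.

```python
learning_rate = 2.0
momentum = 0.9
grad = 0.1

var, accum = 1.0, 0.0
history = []
for _ in range(2):
    # Classic momentum: accumulate the gradient, then step along the accumulator.
    accum = momentum * accum + grad
    var = var - learning_rate * accum
    history.append(var)

# Values written out by hand, as in the expressions above.
var_1 = 1.0 - (0.1 * 2.0)
var_2 = 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
print(history, [var_1, var_2])  # approximately [0.8, 0.42] in both cases
```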
For use_nesterov=True, I had a look at the _update_nesterov_momentum_numpy function and the testNesterovMomentum test case.
The _update_nesterov_momentum_numpy function has the following definition:
def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum):
    var = var + accum * lr * momentum
    accum = accum * momentum + g
    var = var - lr * accum
    var = var - accum * lr * momentum
    return var, accum
and it is called in the unit tests like this:
for t in range(1, 5):
    opt_op.run()
    var0_np, accum0_np = self._update_nesterov_momentum_numpy(
        var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
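For anyone who wants to run this outside the TensorFlow test suite, here is a standalone NumPy sketch. The helper is copied as a free function (no self), the opt_op.run() call is left out, and the starting values of var0_np and accum0_np are assumptions of mine rather than the values used in the test. It also checks that the four lines collapse to a single step, var -= lr * (g + momentum * accum_new), which follows from the algebra of the helper.

```python
import numpy as np


def update_nesterov_momentum_numpy(var, accum, g, lr, momentum):
    # Free-function copy of the helper quoted above.
    var = var + accum * lr * momentum
    accum = accum * momentum + g
    var = var - lr * accum
    var = var - accum * lr * momentum
    return var, accum


# Starting values are assumptions chosen for illustration.
var0_np = np.array([1.0, 2.0])
accum0_np = np.zeros(2)

collapsed_ok = True
for t in range(1, 5):
    g = var0_np * 10  # same gradient expression as in the test loop
    # Net effect of the four lines, collapsed into one step.
    expected_accum = accum0_np * 0.9 + g
    expected_var = var0_np - 2.0 * (g + 0.9 * expected_accum)
    var0_np, accum0_np = update_nesterov_momentum_numpy(
        var0_np, accum0_np, g, 2.0, 0.9)
    collapsed_ok = collapsed_ok and bool(np.allclose(var0_np, expected_var))

print(collapsed_ok)
```

With a learning rate of 2.0 and a gradient of 10 times the variable, the values grow quickly, so this is only useful for comparing implementations, not as an optimisation example.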