The Lookahead optimizer was proposed very recently (https://arxiv.org/abs/1907.08610) and seems very promising (even Hinton is one of the authors!).
Therefore, I wonder how to implement it in Keras. I think many people may need the answer to this question, because we would like to plug it into our existing code in the hope of better results.
I am new to Keras, so any help would be truly appreciated. Thanks!
p.s. source code of existing optimizers: https://github.com/keras-team/keras/blob/master/keras/optimizers.py
To demonstrate the concept, one could implement the Lookahead optimizer as a Keras callback; see my implementation here: https://github.com/kpe/params-flow/blob/master/params_flow/optimizers/lookahead.py
# inside a tf.keras.callbacks.Callback subclass;
# K refers to the Keras backend (from tensorflow.keras import backend as K)
def on_train_batch_end(self, batch, logs=None):
    self.count += 1
    if self.slow_weights is None:
        # take a separate copy of the current weights as the initial slow weights
        self.slow_weights = [K.variable(K.get_value(w))
                             for w in self.model.trainable_weights]
    elif self.count % self.k == 0:
        slow_ups, fast_ups = [], []
        for fast, slow in zip(self.model.trainable_weights,
                              self.slow_weights):
            # slow <- slow + alpha * (fast - slow)
            slow_ups.append(K.update(slow, slow + self.alpha * (fast - slow)))
            # fast <- (updated) slow
            fast_ups.append(K.update(fast, slow))
        K.batch_get_value(slow_ups)  # apply the slow-weight updates first
        K.batch_get_value(fast_ups)  # then reset the fast weights to them
What this does is conceptually embarrassingly simple: every k updates, the weights are moved halfway (alpha=0.5) towards what their value was k iterations ago.
N.B. The above implementation might not work that well on a GPU or TPU, as the slow_weights
copy of the weights would probably get updated on the CPU (and moving the weights takes time).
EDIT (2020.03): There is now an official implementation in TensorFlow Addons! https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/Lookahead
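For reference, a minimal usage sketch could look like the following (the toy model, learning rate and hyper-parameter values below are just placeholders chosen for illustration; see the linked docs for the exact API):

import tensorflow as tf
import tensorflow_addons as tfa

# wrap any base optimizer; sync_period plays the role of k and
# slow_step_size the role of alpha from the paper
base_optimizer = tf.keras.optimizers.Adam(1e-3)
optimizer = tfa.optimizers.Lookahead(base_optimizer,
                                     sync_period=5,
                                     slow_step_size=0.5)

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
model.compile(optimizer=optimizer, loss='mse')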
Today, when I wanted to start implementing it, I found that somebody had already done it! (Of course, when I asked this question, it could not be found via Google.)
Here is the link: https://github.com/bojone/keras_lookahead (For non-Chinese readers, I have slightly modified the repo: https://github.com/fzyzcjy/keras_lookahead.)
The usage is like this:
model.compile(optimizer=Adam(1e-3), loss='mse') # Any optimizer
lookahead = Lookahead(k=5, alpha=0.5) # Initialize Lookahead
lookahead.inject(model) # add into model
Looking into his code, the core of the implementation is the modification of model.train_function, i.e. model.train_function = ..., to achieve the two sets of updates.
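To illustrate the idea only (this is a rough sketch of the pattern, not the repo's actual code; the function name lookahead_updates and the step variable are made up here for illustration), the injected training function essentially carries two sets of symbolic updates: the optimizer's usual fast-weight updates, plus interpolation/reset ops that fire on every k-th batch:

from tensorflow.keras import backend as K

def lookahead_updates(fast_weights, slow_weights, step, k=5, alpha=0.5):
    # step: a backend variable counting training batches
    new_step = step + 1
    sync = K.equal(new_step % k, 0)  # True on every k-th batch
    updates = [K.update(step, new_step)]
    for fast, slow in zip(fast_weights, slow_weights):
        # slow <- slow + alpha * (fast - slow) on sync steps, unchanged otherwise
        new_slow = K.switch(sync, slow + alpha * (fast - slow), slow)
        updates.append(K.update(slow, new_slow))
        # fast <- the new slow value on sync steps, unchanged otherwise
        updates.append(K.update(fast, K.switch(sync, new_slow, fast)))
    return updates

Ops like these would be appended to the optimizer's own updates when rebuilding model.train_function; making sure they run after the fast-weight updates (e.g. via control dependencies) is the delicate part the actual implementation has to handle.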
In addition, judging from his code and comments, the "hacking" trick of the repo seems to come from the following article: https://kexue.fm/archives/5879/comment-page-1 (sorry, it is a non-English page).