How to reset optimizer state in Keras?
Looking at the Optimizer class, I can't see such a method: https://github.com/keras-team/keras/blob/613aeff37a721450d94906df1a3f3cc51e2299d4/keras/optimizers.py#L60
Also, what are self.updates and self.weights, actually?
There isn't an "easy" way to reset the "states", but you can always simply recompile your model with a new optimizer (the model's weights are preserved):
newOptimizer = Adadelta()
model.compile(optimizer=newOptimizer)
You can also use the method set_weights(weightsListInNumpy) (not recommended) from the base Optimizer class, but this would be rather cumbersome: you would need to know all the initial values and shapes, which are sometimes not trivially zeros.
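To make the set_weights route concrete: the state can be pulled out with get_weights(), zeroed, and pushed back with set_weights(). The helper below is a hypothetical sketch, and it assumes zeros are valid initial values for every state slot (true for things like Adam's moment estimates and the iteration counter, but not guaranteed for every optimizer):

```python
import numpy as np

def zeroed_state(state):
    """Return a zero-filled copy of an optimizer state list (numpy arrays)."""
    return [np.zeros_like(w) for w in state]

# Usage on a compiled model (sketch, not verified against every Keras version):
#   state = model.optimizer.get_weights()   # list of numpy arrays
#   model.optimizer.set_weights(zeroed_state(state))
```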
Now, the property self.weights doesn't do much by itself, but the functions that save and load optimizers will save and load this property. It's a list of tensors and should not be changed directly; at most, use K.set_value(...) on each entry of the list. You can see how the weights are saved along with the optimizer in the _serialize_model method.
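To illustrate why K.set_value is used rather than replacing the list: it writes a new value into an existing backend variable in place, so the graph keeps pointing at the same tensor object. Here is a plain-numpy stand-in (the set_value helper is hypothetical, mimicking the behavior of the backend call, not the Keras source):

```python
import numpy as np

def set_value(var, value):
    # Stand-in for K.set_value: write into the existing array in place
    # instead of rebinding the name, so references to `var` stay valid.
    var[...] = value

weights = [np.ones((2, 2)), np.ones(3)]  # stands in for optimizer.weights
for w in weights:
    set_value(w, np.zeros_like(w))
```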
The self.updates list is a little more complex to understand. It stores the variables that will be updated with every batch the model processes during training, but as symbolic graph operations.
As you can see in the code, self.updates is always appended with a K.update(var, value) or K.update_add(var, value) call. This is the correct way to tell the graph that these values should be updated every iteration.
Usually, the updated variables are iterations, params (the model's weights), moments, accumulators, etc.
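The mechanism can be mimicked in plain Python: each pending update pairs a state variable with the new value it should take, and "running" the graph applies all of them once per batch. This is an illustrative SGD-style sketch, not the actual Keras source:

```python
import numpy as np

lr = 0.1
params = {"w": np.array([1.0, 2.0])}   # stands in for the model's weights
grads = {"w": np.array([0.5, 0.5])}    # gradients for the current batch

# Build the pending updates, analogous to appending K.update(var, new_value):
updates = [(name, params[name] - lr * grads[name]) for name in params]

# Each training iteration applies every pending update:
for name, new_value in updates:
    params[name] = new_value

print(params["w"])  # [0.95 1.95]
```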
I don't think there is a universal method for this, but you should be able to reset the state of your optimizer by re-initializing the variables that hold it. This needs to be done with the TensorFlow API, though, and the state variables depend on the specific kind of optimizer. For example, if you have an Adam optimizer (source), you could do the following:
import tensorflow as tf
from keras.optimizers import Adam
from keras import backend as K
optimizer = Adam(...)
# These depend on the optimizer class; for Adam they are the iteration
# counter and the variable hyperparameters
optimizer_state = [optimizer.iterations, optimizer.lr, optimizer.beta_1,
                   optimizer.beta_2, optimizer.decay]
optimizer_reset = tf.variables_initializer(optimizer_state)
# Later, when you want to reset the optimizer
K.get_session().run(optimizer_reset)