I'm trying to change the precision of my (float32) model to float16 to see how much of a performance hit it takes. After loading a model (base_model), I tried this:
from keras import backend as K
K.set_floatx('float16')
weights_list = base_model.layers[1].get_weights()
print('Original:')
print(weights_list[0].dtype)
new_weights = [K.cast_to_floatx(weights_list[0])]
print('New Weights:')
print(new_weights[0].dtype)
print('Setting New Weights')
base_model.layers[1].set_weights(new_weights)
new_weights_list = base_model.layers[1].get_weights()
print(new_weights_list[0].dtype)
Output:
Original:
float32
New Weights:
float16
Setting New Weights
float32
With this code, the weights within one layer are converted to float16 and passed to set_weights, but after calling get_weights, the data type is back to float32. Is there a way to set a layer's dtype? From what I can tell, K.cast_to_floatx is for numpy arrays and K.cast is for tensors. Do I need to go through and construct an entirely new, empty model with the new dtype and put the recast weights into that model?
Or is there some more straightforward way to load a model whose layers all have dtype 'float32' and cast every layer to dtype 'float16'? This is a feature baked into mlmodel, so I figured it wouldn't be particularly difficult in Keras.
Had the same question and got this working. What did not work for me: casting the weights and calling set_weights on the existing model, as in the question, because set_weights converts the arrays back to the dtype the layer was originally built with.
Here is what did work for me: change floatx first, rebuild the model so its layers are created with the new dtype, then copy the cast weights over.
MWE:
>>> from keras import backend as K
>>> from keras.models import Sequential
>>> from keras.layers import Dense, Dropout, Activation
>>> import numpy as np
>>>
>>> def make_model():
...     model = Sequential()
...     model.add(Dense(64, activation='relu', input_dim=20))
...     model.add(Dropout(0.5))
...     model.add(Dense(64, activation='relu'))
...     model.add(Dropout(0.5))
...     model.add(Dense(10, activation='softmax'))
...     return model
...
>>> K.set_floatx('float64')
>>> model = make_model()
>>>
>>> K.set_floatx('float32')
>>> ws = model.get_weights()
>>> wsp = [w.astype(K.floatx()) for w in ws]
>>> model_quant = make_model()
>>> model_quant.set_weights(wsp)
>>> x = np.random.random((100, 20))  # dummy input matching input_dim=20
>>> xp = x.astype(K.floatx())        # inputs must match the recast model's dtype
>>>
>>> print(np.unique([w.dtype for w in model.get_weights()]))
[dtype('float64')]
>>> print(np.unique([w.dtype for w in model_quant.get_weights()]))
[dtype('float32')]