How to convert all layers of a pretrained Keras model to a different dtype (from float32 to float16)?

I'm trying to change the precision of my (float32) model to float16 to see how much of a performance hit it takes. After loading a Model (base_model), I tried this:

from keras import backend as K
K.set_floatx('float16')
weights_list = base_model.layers[1].get_weights()  # base_model: the loaded pretrained model
print('Original:')
print(weights_list[0].dtype)
new_weights = [K.cast_to_floatx(weights_list[0])]
print('New Weights:')
print(new_weights[0].dtype)
print('Setting New Weights')
base_model.layers[1].set_weights(new_weights)
new_weights_list = base_model.layers[1].get_weights()
print(new_weights_list[0].dtype)

Output:

Original:
float32
New Weights:
float16
Setting New Weights
float32

With this code, the weights within one layer are converted to float16, and the model's weights are set to the new weights, but after calling get_weights, the data type goes back to float32. Is there a way to set a layer's dtype? From what I can tell, K.cast_to_floatx is for numpy arrays and K.cast is for tensors (sketched below). Do I need to go through and construct an entirely new, empty model with the new dtype and put the recast weights into it?

Or is there some more straightforward way to load a model whose layers all have dtype 'float32' and cast every layer to dtype 'float16'? This is a feature baked into mlmodel, so I figured it wouldn't be particularly difficult in Keras.
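For reference, a minimal sketch of the cast_to_floatx vs. cast distinction mentioned above (the array and tensor here are only illustrative):

import numpy as np
from keras import backend as K

K.set_floatx('float16')

arr = np.zeros(3, dtype='float32')    # plain numpy array
print(K.cast_to_floatx(arr).dtype)    # float16: cast_to_floatx works on numpy arrays

t = K.constant(arr)                   # backend tensor
print(K.dtype(K.cast(t, 'float16')))  # float16: K.cast works on tensors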

asked Dec 19 '17 by Micah Price

1 Answer

I had the same question and got this working. What did not work for me:

  • Saving to file and loading back
  • Casting all weights and reassigning to original model

Here is what did work for me:

  • Creating a new model of the same architecture and setting its weights manually

MWE:

>>> from keras import backend as K
>>> from keras.models import Sequential
>>> from keras.layers import Dense, Dropout, Activation
>>> import numpy as np
>>> 
>>> def make_model():
...     model = Sequential()
...     model.add(Dense(64, activation='relu', input_dim=20))
...     model.add(Dropout(0.5))
...     model.add(Dense(64, activation='relu'))
...     model.add(Dropout(0.5))
...     model.add(Dense(10, activation='softmax'))
...     return model
... 
>>> K.set_floatx('float64')                   # build the original model in float64
>>> model = make_model()
>>> 
>>> K.set_floatx('float32')                   # target dtype
>>> ws = model.get_weights()
>>> wsp = [w.astype(K.floatx()) for w in ws]  # cast every weight array
>>> model_quant = make_model()                # same architecture, built under float32
>>> model_quant.set_weights(wsp)
>>> x = np.random.rand(100, 20)               # example input data; cast it too
>>> xp = x.astype(K.floatx())
>>> 
>>> print(np.unique([w.dtype for w in model.get_weights()]))
[dtype('float64')]
>>> print(np.unique([w.dtype for w in model_quant.get_weights()]))
[dtype('float32')]
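
The same recipe covers the question's float32 -> float16 case; here is a sketch assuming the make_model() helper from the MWE above (any function that rebuilds the same architecture after set_floatx will do):

>>> K.set_floatx('float16')
>>> model_fp16 = make_model()                 # same architecture, rebuilt under float16
>>> model_fp16.set_weights([w.astype(K.floatx()) for w in model.get_weights()])
>>> print(np.unique([w.dtype for w in model_fp16.get_weights()]))
[dtype('float16')]

Keep in mind that float16 has a much narrower range and lower precision than float32, so verify that accuracy doesn't degrade before relying on the speedup.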
answered Oct 14 '22 by craymichael