 

Save keras model weights directly to bytes/memory?


Keras allows saving entire models or just the model weights (see thread). When saving the weights, they must be written to a file, e.g.:

model = keras_model()
model.save_weights('/tmp/model.h5')

Instead of writing to a file, I'd like to save the bytes directly into memory, something like:

model.dump_weights()

TensorFlow doesn't seem to have this, so as a workaround I'm writing to disk and then reading the file back into memory:

temp = '/tmp/weights.h5'
model.save_weights(temp)
with open(temp, 'rb') as f:
    weightbytes = f.read()

Is there any way to avoid this roundabout step?

Adam Hughes asked Mar 06 '20 16:03


2 Answers

weights = model.get_weights() returns the model weights as a list of NumPy arrays held in memory, and model.set_weights(weights) loads them back into the model. One issue, though, is WHEN to save the weights. Generally you want the weights from the epoch with the lowest validation loss. The Keras callback ModelCheckpoint (with save_best_only=True) will save those weights, but only to a file. I found saving to a file inconvenient, so I wrote a small custom callback that keeps the weights with the lowest validation loss in a class variable; after training is complete, load those weights into the model to make predictions. The code is shown below. Just pass an instance of save_best_weights in the callbacks list when you call model.fit() (callbacks go to fit(), not compile()).

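import numpy as np
import tensorflow as tf

class save_best_weights(tf.keras.callbacks.Callback):
    best_weights = None  # class variable: best weights seen so far (list of NumPy arrays)

    def __init__(self):
        super().__init__()
        self.best = np.inf  # np.Inf was removed in NumPy 2.0

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        current_loss = logs.get('val_loss')  # None if fit() got no validation data
        if current_loss is not None and np.less(current_loss, self.best):
            self.best = current_loss
            # self.model is the model this callback was attached to by fit()
            save_best_weights.best_weights = self.model.get_weights()
            accuracy = logs.get('val_accuracy', float('nan')) * 100
            print('\nSaving weights validation loss= {0:6.4f}  validation accuracy= {1:6.3f} %\n'.format(current_loss, accuracy))

# Usage:
# model.fit(x, y, validation_data=(x_val, y_val), callbacks=[save_best_weights()])
# model.set_weights(save_best_weights.best_weights)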
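The weights kept by the callback (like anything returned by get_weights()) are still a list of NumPy arrays, not bytes. If you need actual bytes without touching disk, here is a minimal sketch, assuming NumPy is available: pack the arrays into an in-memory .npz archive with np.savez and io.BytesIO. The helper names weights_to_bytes and bytes_to_weights are hypothetical, and dummy arrays stand in for a real model's weights.

```python
import io

import numpy as np


def weights_to_bytes(weights):
    """Serialize a list of NumPy arrays (e.g. model.get_weights()) to bytes."""
    buffer = io.BytesIO()
    # np.savez stores positional arrays under the keys arr_0, arr_1, ...
    np.savez(buffer, *weights)
    return buffer.getvalue()


def bytes_to_weights(data):
    """Inverse of weights_to_bytes: rebuild the list in its original order."""
    # allow_pickle=False: only plain numeric arrays are accepted on load
    with np.load(io.BytesIO(data), allow_pickle=False) as archive:
        return [archive[f"arr_{i}"] for i in range(len(archive.files))]


# Dummy stand-ins for model.get_weights(); with a real model you would use
#   blob = weights_to_bytes(model.get_weights())
#   model.set_weights(bytes_to_weights(blob))
weights = [np.ones((2, 3), dtype=np.float32), np.arange(3, dtype=np.float32)]
blob = weights_to_bytes(weights)
restored = bytes_to_weights(blob)
```

One reason to prefer this over pickling is that loading with allow_pickle=False never executes code embedded in the data, which matters if the bytes come from a database or another process.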

Gerry P answered Oct 02 '22 15:10


Convert the model architecture to JSON and serialize it with dill, then store the resulting bytes (base64-encode them if you need to store them in a database). Do the same for the model weights. Everything happens in memory; nothing touches the disk.

from io import BytesIO
import dill,base64,tempfile

#Saving Model as base64
model_json = Keras_model.to_json()

def Base64Converter(ObjectFile):
    bytes_container = BytesIO()
    dill.dump(ObjectFile, bytes_container)
    bytes_container.seek(0)
    bytes_file = bytes_container.read()
    base64File = base64.b64encode(bytes_file)
    return base64File

base64KModelJson = Base64Converter(model_json)  
base64KModelJsonWeights = Base64Converter(Keras_model.get_weights())  

To load it back, decode the base64, unpickle with dill from an in-memory buffer, and rebuild the model with model_from_json:

#Loading Back
from keras.models import model_from_json

def ObjectConverter(base64_File):
    loaded_binary = base64.b64decode(base64_File)
    # dill.load is the inverse of dill.dump; reading from BytesIO avoids a temporary file on disk
    ObjectFile = dill.load(BytesIO(loaded_binary))
    return ObjectFile

modeljson = ObjectConverter(base64KModelJson)
modelweights = ObjectConverter(base64KModelJsonWeights)
loaded_model = model_from_json(modeljson)
loaded_model.set_weights(modelweights)
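For the two objects involved here, a JSON string and a list of NumPy arrays, the standard-library pickle module should be enough; dill mainly adds support for objects pickle cannot handle, such as lambdas and closures. Below is a minimal sketch of the same base64 round trip using only pickle and base64. The names to_base64 and from_base64 are hypothetical, and plain Python values stand in for Keras_model.to_json() and Keras_model.get_weights(). Note that to_json() already returns a string, so it could also be stored as-is without pickling.

```python
import base64
import pickle


def to_base64(obj):
    """Pickle obj in memory and return base64-encoded (text-safe) bytes."""
    return base64.b64encode(pickle.dumps(obj))


def from_base64(data):
    """Decode base64 and unpickle. Only use this on data you trust."""
    return pickle.loads(base64.b64decode(data))


# Stand-ins for Keras_model.to_json() and Keras_model.get_weights()
model_json = '{"class_name": "Sequential", "config": {}}'
weights = [[0.5, -1.25], [3.0]]

encoded_json = to_base64(model_json)
encoded_weights = to_base64(weights)
```

As with dill, unpickling data from an untrusted source can execute arbitrary code, so only decode blobs that your own application wrote.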
hanzgs answered Oct 02 '22 15:10