Keras load_model with custom objects doesn't work properly

Setting

As already mentioned in the title, I have a problem with my custom loss function when trying to load the saved model. My loss looks as follows:

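from keras import backend as K

def weighted_cross_entropy(weights):
    # Store the per-class weights as a backend variable
    weights = K.variable(weights)

    def loss(y_true, y_pred):
        # Clip predictions away from 0 and 1 so K.log never sees 0
        y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())

        # Weighted cross-entropy, summed over the class axis
        loss = y_true * K.log(y_pred) * weights
        loss = -K.sum(loss, -1)
        return loss

    return loss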

weighted_loss = weighted_cross_entropy([0.1,0.9])
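To make the arithmetic concrete: for one sample, the loss above computes -Σ_c w_c · y_c · log(p_c), with each prediction p_c clipped to [ε, 1 - ε]. The sketch below is a plain-Python stand-in that uses `math` instead of the Keras backend so it runs on its own; the name `weighted_cross_entropy_py` and the `eps=1e-7` default (Keras's usual `K.epsilon()` value) are assumptions for illustration, not part of the original code.

```python
import math


def weighted_cross_entropy_py(weights, eps=1e-7):
    """Hypothetical plain-Python stand-in for the Keras loss (one sample at a time)."""
    def loss(y_true, y_pred):
        # Clip predictions to [eps, 1 - eps] so math.log never sees 0
        clipped = [min(max(p, eps), 1 - eps) for p in y_pred]
        # Sum of target * log(prediction) * weight over the classes, negated
        return -sum(t * math.log(p) * w for t, p, w in zip(y_true, clipped, weights))
    return loss


weighted_loss_py = weighted_cross_entropy_py([0.1, 0.9])
# One-hot target for class 1, predicted probability 0.8 for that class
print(round(weighted_loss_py([0, 1], [0.2, 0.8]), 4))  # 0.2008
```

Only the true class contributes, so the result is 0.9 × -log(0.8) ≈ 0.2008 instead of the unweighted ≈ 0.2231.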

During training I used the weighted_loss function as the loss function, and everything worked well. When training finished, I saved the model as an .h5 file with the standard model.save function from the Keras API.
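Why the name matters: as far as we understand Keras 2.x, `model.save` writes the compile settings into the .h5 file as JSON, and a plain Python function is stored only by its `__name__`. The function returned by `weighted_cross_entropy` is the inner `def loss`, so the string recorded is `"loss"`, not `"weighted_loss"`, and the weights are not recorded at all. The sketch below imitates that step with `json.dumps`; `to_json_name` is a hypothetical simplification rather than Keras code, and the loss body is a stub because only the name matters here.

```python
import json


def weighted_cross_entropy(weights):
    # Stub with the same structure as the Keras version; only the closure's name matters
    def loss(y_true, y_pred):
        return 0.0
    return loss


weighted_loss = weighted_cross_entropy([0.1, 0.9])


def to_json_name(obj):
    # Hypothetical simplification: callables are serialized by their __name__
    if callable(obj):
        return obj.__name__
    raise TypeError(f"Not JSON serializable: {obj!r}")


training_config = {"loss": weighted_loss}
print(json.dumps(training_config, default=to_json_name))  # {"loss": "loss"}
```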

Problem

When I try to load the model via

from keras.models import load_model

model = load_model(path, custom_objects={"weighted_loss": weighted_loss})

I get a ValueError telling me that the loss function is unknown.

Error

The error message looks as follows:

  File "...\predict.py", line 29, in my_script
    "weighted_loss": weighted_loss})
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 419, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 312, in _deserialize_model
    sample_weight_mode=sample_weight_mode)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\training.py", line 139, in compile
    loss_function = losses.get(loss)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 133, in get
    return deserialize(identifier)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 114, in deserialize
    printable_module_name='loss function')
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\utils\generic_utils.py", line 165, in deserialize_keras_object
    ':' + function_name)
ValueError: Unknown loss function:loss
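On loading, the stored string is resolved by name: Keras checks `custom_objects` first, then its own registered and built-in losses, and raises if none has the key. The hypothetical `resolve_loss` below is a simplified model of that lookup (it is not Keras's `deserialize_keras_object`) and reproduces the message from the traceback when the dictionary is keyed by `"weighted_loss"` while the saved name, assumed here to be `"loss"`, is something else.

```python
def resolve_loss(name, custom_objects, builtin_losses):
    """Hypothetical, simplified version of Keras's name-based loss lookup."""
    # 1. user-supplied custom_objects are checked first
    if custom_objects and name in custom_objects:
        return custom_objects[name]
    # 2. fall back to the built-in losses
    if name in builtin_losses:
        return builtin_losses[name]
    # 3. nothing matched: same message format as the traceback
    raise ValueError("Unknown loss function:" + name)


def weighted_loss(y_true, y_pred):
    # Stand-in for the user's closure; its body does not affect the lookup
    return 0.0


builtin_losses = {"categorical_crossentropy": object()}  # hypothetical built-in table
saved_name = "loss"  # assumed string recorded in the .h5 file

message = None
try:
    resolve_loss(saved_name, {"weighted_loss": weighted_loss}, builtin_losses)
except ValueError as err:
    message = str(err)
print(message)  # Unknown loss function:loss

# Keying custom_objects by the saved name resolves to the user's function
print(resolve_loss(saved_name, {"loss": weighted_loss}, builtin_losses) is weighted_loss)  # True
```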

Questions

How can I fix this problem? Could the reason be my wrapped loss definition, i.e. that Keras doesn't know how to handle the weights variable?

pafi asked Mar 18 '19 13:03


1 Answer

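The function returned by weighted_cross_entropy is the inner one, so your loss function's name is loss (i.e. def loss(y_true, y_pred):), and that is the name Keras stored when the model was saved, as the "Unknown loss function:loss" message shows. Therefore, when loading the model back, you need to use 'loss' as the key in custom_objects: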

model = load_model(path, custom_objects={'loss': weighted_loss})
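Two things follow from this. The object you pass in `custom_objects` is the one the reloaded model uses, so rebuild it with the same weights you trained with, e.g. `weighted_cross_entropy([0.1, 0.9])`; the weights themselves are not restored from the file. And if you prefer to keep the key `"weighted_loss"`, one option, assuming Keras records a plain function by its `__name__` (which is what the `Unknown loss function:loss` message suggests), is to rename the closure before compiling and saving. The stub below only checks the resulting name in plain Python; its body is a placeholder for the real Keras loss from the question.

```python
def weighted_cross_entropy(weights):
    def loss(y_true, y_pred):
        return 0.0  # placeholder; use the Keras loss body from the question
    # Rename the closure so the name saved with the model matches the custom_objects key
    loss.__name__ = "weighted_loss"
    return loss


weighted_loss = weighted_cross_entropy([0.1, 0.9])
print(weighted_loss.__name__)  # weighted_loss
```

A model saved after this change should then load with `load_model(path, custom_objects={"weighted_loss": weighted_loss})`; models saved earlier still record `"loss"`. If you only need predictions, `load_model(path, compile=False)` skips restoring the loss altogether.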
today answered Sep 28 '22 10:09