I have a multi-label classification problem with 5 labels (e.g. [1 0 1 1 0]). Therefore, I want my model to improve at metrics such as recall at a fixed precision, precision-recall AUC, or ROC AUC.
It doesn't make sense to use a loss function (e.g. binary_crossentropy) that is not directly related to the performance measure I want to optimize. Therefore, I want to use TensorFlow's global_objectives.recall_at_precision_loss() or something similar as the loss function.
I'm not looking to implement a tf.metrics function. I already succeeded in that by following https://stackoverflow.com/a/50566908/3399066.
I think my issue can be divided into 2 problems:

1. How do I use global_objectives.recall_at_precision_loss() or similar in a simple TF example?
There is a file called loss_layers_example.py on the Global Objectives GitHub page (same as above). However, since I don't have much experience with TF, I don't really understand how to use it. Also, Googling for "TensorFlow recall_at_precision_loss example" or "TensorFlow Global Objectives example" doesn't give me any clearer example.

2. How do I use loss functions similar to global_objectives.recall_at_precision_loss() in Keras?
Would something like model.compile(loss = ??.recall_at_precision_loss, ...) be enough? My feeling tells me it is more complex than that, due to the global variables used in loss_layers_example.py.
I managed to make it work by:
import keras
from global_objectives import loss_layers, util

BATCH_SIZE = 160  # assumption: fit() batch size (e.g. 32) times the 5 labels

def precision_recall_auc_loss(y_true, y_pred):
    # Flatten the (batch, num_labels) tensors into a single binary column.
    y_true = keras.backend.reshape(y_true, (BATCH_SIZE, 1))
    y_pred = keras.backend.reshape(y_pred, (BATCH_SIZE, 1))
    util.get_num_labels = lambda labels: 1  # make the library treat everything as one label
    return loss_layers.precision_recall_auc_loss(y_true, y_pred)[0]  # [0] is the loss tensor
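For context, here is a rough sketch of how this could plug into a Keras model. The toy architecture, input shape, random data, and the assumption that BATCH_SIZE equals fit()'s batch_size times the number of labels are mine, for illustration only; note that the Global Objectives losses work on logits, so the output layer below has no sigmoid activation.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

NUM_LABELS = 5  # matches the 5-label problem in the question
model = Sequential([
    Dense(32, activation='relu', input_shape=(10,)),
    Dense(NUM_LABELS),  # linear outputs: the loss expects logits
])
model.compile(optimizer='adam', loss=precision_recall_auc_loss)

# The fixed reshape ties training to one batch size:
# 32 samples * 5 labels = BATCH_SIZE (160) elements per batch.
x_train = np.random.rand(320, 10)
y_train = np.random.randint(0, 2, size=(320, NUM_LABELS)).astype('float32')
model.fit(x_train, y_train, batch_size=BATCH_SIZE // NUM_LABELS, epochs=2)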
Similar to Martino's answer, but this version infers the shape from the input (setting it to a fixed batch size did not work for me). The outer function isn't strictly necessary, but it feels more natural to pass parameters as you configure the loss function, especially when the wrapper is defined in an external module.
import keras.backend as K
from global_objectives.loss_layers import precision_at_recall_loss

def get_precision_at_recall_loss(target_recall):
    def precision_at_recall_loss_wrapper(y_true, y_pred):
        # Reshape to (num_elements, 1); -1 keeps the batch size dynamic.
        y_true = K.reshape(y_true, (-1, 1))
        y_pred = K.reshape(y_pred, (-1, 1))
        # The loss returns (loss, other_outputs); Keras only needs the loss tensor.
        return precision_at_recall_loss(y_true, y_pred, target_recall)[0]
    return precision_at_recall_loss_wrapper
Then, when compiling the model:
TARGET_RECALL = 0.9
model.compile(optimizer='adam', loss=get_precision_at_recall_loss(TARGET_RECALL))
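And a quick end-to-end check that the dynamic reshape really does free you from a fixed batch size; the tiny model and random data below are illustrative assumptions only (again with linear outputs, since the Global Objectives losses expect logits):

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

NUM_LABELS = 5
model = Sequential([
    Dense(32, activation='relu', input_shape=(10,)),
    Dense(NUM_LABELS),  # logits for the 5 labels
])
model.compile(optimizer='adam', loss=get_precision_at_recall_loss(TARGET_RECALL))

x = np.random.rand(200, 10)
y = np.random.randint(0, 2, size=(200, NUM_LABELS)).astype('float32')
# Any batch size works, including the smaller remainder batch at the end of an epoch.
model.fit(x, y, batch_size=16, epochs=2)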