 

How to make Keras compute a certain metric on validation data only?

I'm using tf.keras with TensorFlow 1.14.0. I have implemented a custom metric that is quite computationally intensive, and it slows down training if I simply add it to the list of metrics passed to model.compile(..., metrics=[...]).

How do I make Keras skip computation of the metric during training iterations but compute it on validation data (and print it) at the end of each epoch?

asked Jun 30 '19 by maruan

People also ask

How do you define metrics in Keras?

The easiest way of defining a metric in Keras is to simply use a callback function. The function takes two arguments: the first is the ground truth (y_true) and the second is the model's prediction (y_pred). When the metric is evaluated, these arguments are Tensors, so the calculation has to use tensor operations (for example, the Keras backend) rather than plain numpy.
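For instance, a custom RMSE metric can be written as a plain function and passed to compile(). This is a minimal sketch assuming a TF 2.x-style tf.keras; the one-layer model is illustrative, and tf.* ops are used in place of the older Keras backend functions:

```python
import tensorflow as tf

def root_mean_squared_error(y_true, y_pred):
    # y_true and y_pred arrive as Tensors, so use tensor ops, not numpy
    return tf.sqrt(tf.reduce_mean(tf.square(y_pred - y_true)))

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse',
              metrics=[root_mean_squared_error])
```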

How do you monitor model performance in Keras?

Sometimes you want to monitor model performance by looking at charts like a ROC curve or confusion matrix after every epoch. In Keras, metrics are passed during the compile stage as shown below. You can pass several metrics by comma-separating them in a list.
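For example, built-in metric names and metric objects can be mixed freely in that list (the small model here is only illustrative):

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(3,)),
])
# Several metrics, comma-separated in one list
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy', tf.keras.metrics.AUC(), tf.keras.metrics.Precision()])
```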

How do you calculate accuracy in Keras?

Accuracy is the percentage of predicted values (yPred) that match the actual values (yTrue): a record is counted as accurate if its predicted value equals its actual value. You can compute this with Keras' built-in method and check that it agrees with that logical definition.
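The check can be done directly against Keras' BinaryAccuracy metric; the 0.5 threshold below is its default, and the tiny arrays are illustrative:

```python
import numpy as np
import tensorflow as tf

y_true = np.array([1., 0., 1., 1.], dtype='float32')
y_pred = np.array([0.9, 0.2, 0.4, 0.8], dtype='float32')  # predicted probabilities

# Keras' method
m = tf.keras.metrics.BinaryAccuracy(threshold=0.5)
m.update_state(y_true, y_pred)
keras_acc = float(m.result())

# Logical definition: fraction of records where prediction matches truth
manual_acc = float(np.mean((y_pred > 0.5).astype('float32') == y_true))
```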

What is MeanIoU in Keras?

tf.keras.metrics.MeanIoU – Mean Intersection-Over-Union is a metric used for the evaluation of semantic image segmentation models. The IoU is first calculated for each class, then averaged over the classes.
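A small worked example (MeanIoU expects class indices, not probabilities; the labels here are illustrative):

```python
import tensorflow as tf

m = tf.keras.metrics.MeanIoU(num_classes=2)
m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
# Class 0: TP=1, union=3 -> IoU 1/3; class 1 likewise; mean is ~0.333
mean_iou = float(m.result())
```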


1 Answer

To do this you can create a tf.Variable in the metric that determines whether the calculation goes ahead, and then update it from a callback when validation runs. For example:

class MyCustomMetric(tf.keras.metrics.Metric):

    def __init__(self, **kwargs):
        # Initialise as normal and add a flag variable for when to run the computation
        super(MyCustomMetric, self).__init__(**kwargs)
        self.metric_variable = self.add_weight(name='metric_variable', initializer='zeros')
        self.update_metric = tf.Variable(False)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Only do the expensive computation when the flag is set;
        # tf.cond keeps this valid when update_state is traced into a graph.
        # `computation_result` stands in for your metric's computed value.
        tf.cond(self.update_metric,
                lambda: self.metric_variable.assign_add(computation_result),
                lambda: self.metric_variable.assign_add(0.))

    def result(self):
        return self.metric_variable

    def reset_states(self):
        self.metric_variable.assign(0.)

class ToggleMetrics(tf.keras.callbacks.Callback):
    '''On test begin (i.e. when evaluate() is called or
    validation data is run during fit()) toggle the metric flag.'''
    def on_test_begin(self, logs=None):
        for metric in self.model.metrics:
            if isinstance(metric, MyCustomMetric):
                metric.update_metric.assign(True)

    def on_test_end(self, logs=None):
        for metric in self.model.metrics:
            if isinstance(metric, MyCustomMetric):
                metric.update_metric.assign(False)
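Putting the pattern together, here is a minimal self-contained sketch of the same idea. The ValOnlyMSE metric and its accumulation logic are illustrative stand-ins for the expensive computation, not part of the original answer:

```python
import tensorflow as tf

class ValOnlyMSE(tf.keras.metrics.Metric):
    '''Accumulates squared error only while its flag is switched on.'''
    def __init__(self, name='val_only_mse', **kwargs):
        super().__init__(name=name, **kwargs)
        self.total = self.add_weight(name='total', initializer='zeros')
        self.count = self.add_weight(name='count', initializer='zeros')
        self.update_metric = tf.Variable(False, trainable=False)

    def update_state(self, y_true, y_pred, sample_weight=None):
        def accumulate():
            err = tf.reduce_sum(tf.square(tf.cast(y_true, tf.float32)
                                          - tf.cast(y_pred, tf.float32)))
            self.total.assign_add(err)
            self.count.assign_add(tf.cast(tf.size(y_true), tf.float32))
            return tf.constant(True)
        # tf.cond keeps this graph-compatible; the false branch is a no-op
        tf.cond(self.update_metric, accumulate, lambda: tf.constant(False))

    def result(self):
        return tf.math.divide_no_nan(self.total, self.count)

    def reset_state(self):
        self.total.assign(0.)
        self.count.assign(0.)

class ToggleMetrics(tf.keras.callbacks.Callback):
    '''Switch the flag on while validation/evaluation runs, off afterwards.'''
    def on_test_begin(self, logs=None):
        for m in self.model.metrics:
            if isinstance(m, ValOnlyMSE):
                m.update_metric.assign(True)

    def on_test_end(self, logs=None):
        for m in self.model.metrics:
            if isinstance(m, ValOnlyMSE):
                m.update_metric.assign(False)
```

Compile with metrics=[ValOnlyMSE()] and pass callbacks=[ToggleMetrics()] to fit(); the metric then stays idle during training steps and accumulates only while validation data is run.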
answered Sep 22 '22 by Cptn.Redbeard