
Access deprecated attribute "validation_data" in tf.keras.callbacks.Callback

I decided to switch from keras to tf.keras (as recommended here). Therefore I installed TensorFlow 2.0.0 (tf.__version__ = 2.0.0, tf.keras.__version__ = 2.2.4-tf). In an older version of my code (running on TensorFlow 1.x) I used a callback to compute custom metrics on the entire validation data at the end of each epoch. The idea was taken from here. However, it seems that the "validation_data" attribute is deprecated, so the following code no longer works.

import numpy as np
from tensorflow.keras.callbacks import Callback

# mse_score is my own metric function, defined elsewhere in my code.

class ValMetrics(Callback):

    def on_train_begin(self, logs={}):
        self.val_epoch_mse = []

    def on_epoch_end(self, epoch, logs):
        # Fails with tf.keras 2.x: self.validation_data is no longer populated.
        val_predict = np.asarray(self.model.predict(self.validation_data[0]))
        val_targ = self.validation_data[1]

        val_epoch_mse = mse_score(val_targ, val_predict)
        self.val_epoch_mse.append(val_epoch_mse)

        # Add custom metrics to the logs, so that we can use them with the
        # EarlyStopping and CSVLogger callbacks
        logs["val_epoch_mse"] = val_epoch_mse

        print(f"\nEpoch: {epoch + 1}")
        print("-----------------")
        print("val_mse:     {:+.6f}".format(val_epoch_mse))

My current workaround is the following: I simply pass validation_data as an argument to the ValMetrics class:

class ValMetrics(Callback):

    def __init__(self, validation_data):
        super().__init__()
        self.X_val, self.y_val = validation_data
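
For completeness, this is how I then wire it up (the model and data below are just placeholders, and on_epoch_end is assumed to use self.X_val and self.y_val instead of self.validation_data):

import numpy as np
import tensorflow as tf

# Placeholder data and model purely for illustration.
X_train, y_train = np.random.rand(100, 8), np.random.rand(100)
X_val, y_val = np.random.rand(20, 8), np.random.rand(20)

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
model.compile(optimizer="adam", loss="mse")

# Hand the validation split to the callback explicitly (the workaround),
# in addition to passing it to fit() for the built-in val_loss.
model.fit(X_train, y_train,
          validation_data=(X_val, y_val),
          epochs=5,
          callbacks=[ValMetrics(validation_data=(X_val, y_val))])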

Still, I have some questions: Is the "validation_data" attribute really deprecated, or can it be found elsewhere? And is there a better way to access the validation data at the end of each epoch than with the above workaround?

Thanks a lot!

asked Feb 05 '20 by Constih

1 Answer

You are right that the validation_data argument is deprecated, as per the TensorFlow Callbacks documentation.

The issue you are facing has been raised on GitHub; related issues are Issue1, Issue2 and Issue3.

None of those GitHub issues has been resolved yet. Your workaround of passing the validation data as an argument to the custom callback is a good one, as per this GitHub comment, and many people have found it useful.

The workaround code is given below for the benefit of the Stack Overflow community, even though it is already available on GitHub.

import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.keras.callbacks import Callback


class Metrics(Callback):

    def __init__(self, val_data, batch_size=20):
        super().__init__()
        # val_data is assumed to be batch-indexable (e.g. a
        # tf.keras.utils.Sequence) yielding (x, y) pairs of size batch_size.
        self.validation_data = val_data
        self.batch_size = batch_size

    def on_train_begin(self, logs={}):
        print(self.validation_data)  # sanity check: the callback received the data
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs={}):
        batches = len(self.validation_data)
        total = batches * self.batch_size

        val_pred = np.zeros((total, 1))
        val_true = np.zeros((total))

        for batch in range(batches):
            xVal, yVal = self.validation_data[batch]
            val_pred[batch * self.batch_size : (batch + 1) * self.batch_size] = \
                np.asarray(self.model.predict(xVal)).round()
            val_true[batch * self.batch_size : (batch + 1) * self.batch_size] = yVal

        val_pred = np.squeeze(val_pred)
        _val_f1 = f1_score(val_true, val_pred)
        _val_precision = precision_score(val_true, val_pred)
        _val_recall = recall_score(val_true, val_pred)

        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
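
To show how the callback can be plugged in, here is a small usage sketch of my own (not from the GitHub thread); it assumes the validation set is wrapped in a tf.keras.utils.Sequence that yields fixed-size batches, and the model and data are placeholders:

import numpy as np
import tensorflow as tf

class ValSequence(tf.keras.utils.Sequence):
    """Minimal Sequence wrapping in-memory arrays into fixed-size batches."""

    def __init__(self, x, y, batch_size):
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        return len(self.x) // self.batch_size

    def __getitem__(self, idx):
        s = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[s], self.y[s]

# Placeholder binary-classification data and model purely for illustration.
X_train, y_train = np.random.rand(200, 8), np.random.randint(0, 2, 200)
X_val, y_val = np.random.rand(60, 8), np.random.randint(0, 2, 60)

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, activation="sigmoid", input_shape=(8,))])
model.compile(optimizer="adam", loss="binary_crossentropy")

val_seq = ValSequence(X_val, y_val, batch_size=20)
model.fit(X_train, y_train,
          validation_data=(X_val, y_val),
          epochs=3,
          callbacks=[Metrics(val_data=val_seq, batch_size=20)])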

I will keep following the GitHub issues mentioned above and will update the answer accordingly.

Hope this helps. Happy Learning!

answered Sep 26 '22 by Tensorflow Warrior