Calculate recall for each class after each epoch in Tensorflow 2

I am trying to calculate the recall for each class after each epoch, in both binary and multi-class (one-hot encoded) classification scenarios, for a model that uses TensorFlow 2's Keras API. For example, for binary classification I'd like to be able to do something like

import tensorflow as tf
import matplotlib.pyplot as plt
model = tf.keras.Sequential()
model.add(...)
model.add(tf.keras.layers.Dense(1))

model.compile(metrics=[binary_recall(label=0), binary_recall(label=1)], ...)
history = model.fit(...)

plt.plot(history.history['binary_recall_0'])
plt.plot(history.history['binary_recall_1'])
plt.show()

or, in a multi-class scenario, I'd like to do something like

model = tf.keras.Sequential()
model.add(...)
model.add(tf.keras.layers.Dense(3))

model.compile(metrics=[recall(label=0), recall(label=1), recall(label=2)], ...)
history = model.fit(...)

plt.plot(history.history['recall_0'])
plt.plot(history.history['recall_1'])
plt.plot(history.history['recall_2'])
plt.show()

I'm working on a classifier for an unbalanced dataset and want to be able to see at what point the recall for my minority class(es) starts to degrade.

I found an implementation of precision for a specific class in a multi-class classifier here: https://stackoverflow.com/a/41717938/373655. I am trying to adapt it to what I need, but keras.backend is still pretty foreign to me, so any help would be greatly appreciated.

I am also unclear whether I can use Keras metrics (which are calculated at the end of each batch and then averaged) or whether I need to use Keras callbacks (which can run at the end of each epoch). It seems to me that it shouldn't make a difference for recall (e.g. 8/10 == (3/5 + 5/5) / 2), but this is the reason recall was removed in Keras 2, so maybe I'm missing something (https://github.com/keras-team/keras/issues/5794).
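The 8/10 example only works because both batches happen to contain the same number of positives. The pure-Python sketch below (the two batches are made-up toy data, and `recall` is a hypothetical helper) shows that when batches contain different numbers of positives, the mean of per-batch recalls differs from the recall computed over the whole epoch. As far as I understand, TF2's `tf.keras.metrics.Recall` is stateful: it accumulates true positives and false negatives across batches and reports the pooled value, so the averaging problem mainly concerned the stateless, batch-averaged metric functions of older Keras.

```python
def recall(y_true, y_pred, label):
    """Recall for `label`: correctly predicted `label` samples / all true `label` samples."""
    positives = [p for t, p in zip(y_true, y_pred) if t == label]
    return sum(1 for p in positives if p == label) / len(positives)

# Two hypothetical batches with different numbers of class-1 samples.
batches = [
    ([1, 0, 0, 0], [0, 0, 0, 0]),      # 1 positive, 0 found
    ([1] * 9 + [0], [1] * 9 + [0]),    # 9 positives, 9 found
]

# What a batch-averaged metric reports.
per_batch = [recall(t, p, 1) for t, p in batches]
mean_of_batches = sum(per_batch) / len(per_batch)

# Pool every sample of the "epoch" and compute recall once.
all_true = [t for y_true, _ in batches for t in y_true]
all_pred = [p for _, y_pred in batches for p in y_pred]
pooled = recall(all_true, all_pred, 1)

print(per_batch)        # [0.0, 1.0]
print(mean_of_batches)  # 0.5
print(pooled)           # 0.9
```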

Edit: partial solution (multi-class classification). @mujjiga's solution works for both binary and multi-class classification, but as @P-Gn pointed out, TensorFlow 2's Recall metric supports this out of the box for multi-class classification, e.g.

from tensorflow.keras.metrics import Recall
import matplotlib.pyplot as plt

model = ...

model.compile(loss='categorical_crossentropy', metrics=[
    Recall(class_id=0, name='recall_0'),
    Recall(class_id=1, name='recall_1'),
    Recall(class_id=2, name='recall_2'),
])

history = model.fit(...)

plt.plot(history.history['recall_2'])
plt.plot(history.history['val_recall_2'])
plt.show()
asked May 30 '19 16:05 by rob


2 Answers

We can use sklearn's classification_report together with a Keras Callback to achieve this.

Working code sample (with comments)

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import Callback
from sklearn.metrics import classification_report
from sklearn.datasets import make_classification

# Model -- Binary classifier
binary_model = Sequential()
binary_model.add(Dense(16, input_shape=(2,), activation='relu'))
binary_model.add(Dense(8, activation='relu'))
binary_model.add(Dense(1, activation='sigmoid'))
binary_model.compile('adam', loss='binary_crossentropy')

# Model -- Multiclass classifier
multiclass_model = Sequential()
multiclass_model.add(Dense(16, input_shape=(2,), activation='relu'))
multiclass_model.add(Dense(8, activation='relu'))
multiclass_model.add(Dense(3, activation='softmax'))
multiclass_model.compile('adam', loss='categorical_crossentropy')

# Callback that computes per-class metrics on (x, y) at the end of every epoch
class Metrics(Callback):
    def __init__(self, x, y):
        super().__init__()
        self.x = x
        # Keep integer labels as-is; convert one-hot targets to class indices
        self.y = y if (y.ndim == 1 or y.shape[1] == 1) else np.argmax(y, axis=1)
        self.reports = []

    def on_epoch_end(self, epoch, logs=None):
        y_hat = np.asarray(self.model.predict(self.x, verbose=0))
        # Sigmoid output: threshold at 0.5; softmax output: take the most likely class
        if y_hat.ndim == 1 or y_hat.shape[1] == 1:
            y_hat = np.where(y_hat.reshape(-1) > 0.5, 1, 0)
        else:
            y_hat = np.argmax(y_hat, axis=1)
        report = classification_report(self.y, y_hat, output_dict=True)
        self.reports.append(report)

    # Utility method: per-epoch history of one metric ('recall', 'precision', 'f1-score') for one class
    def get(self, metrics, of_class):
        return [report[str(of_class)][metrics] for report in self.reports]

# Generate some train data (2 classes) and train
x, y = make_classification(n_features=2, n_redundant=0, n_informative=2,
                           random_state=1, n_clusters_per_class=1)
metrics_binary = Metrics(x, y)
binary_model.fit(x, y, epochs=30, callbacks=[metrics_binary])

# Generate some train data (3 classes) and train
x, y = make_classification(n_features=2, n_redundant=0, n_informative=2,
                           random_state=1, n_clusters_per_class=1, n_classes=3)
y = tf.keras.utils.to_categorical(y, 3)
metrics_multiclass = Metrics(x, y)
multiclass_model.fit(x, y, epochs=30, callbacks=[metrics_multiclass])

# Plotting 
plt.close('all')
plt.plot(metrics_binary.get('recall',0), label='Class 0 recall') 
plt.plot(metrics_binary.get('recall',1), label='Class 1 recall') 

plt.plot(metrics_binary.get('precision',0), label='Class 0 precision') 
plt.plot(metrics_binary.get('precision',1), label='Class 1 precision') 

plt.plot(metrics_binary.get('f1-score',0), label='Class 0 f1-score') 
plt.plot(metrics_binary.get('f1-score',1), label='Class 1 f1-score') 
plt.legend(loc='lower right')
plt.show()

plt.close('all')
for m in ['recall', 'precision', 'f1-score']:
    for c in [0,1,2]:
        plt.plot(metrics_multiclass.get(m,c), label='Class {0} {1}'.format(c,m))
        
plt.legend(loc='lower right')
plt.show()


Advantages:

  • classification_report provides many metrics per class (precision, recall, f1-score and support)
  • Metrics can be calculated on validation data or on training data by passing the corresponding data to the Metrics constructor.
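The recall values stored in each report are plain per-class recall, TP / (TP + FN), computed after the label conversion done in on_epoch_end. As a rough NumPy sketch (the helper names true_labels, predicted_labels and per_class_recall are hypothetical, not part of sklearn or Keras, and the inputs are toy data), the same conversion and the same numbers can be reproduced without sklearn:

```python
import numpy as np

def true_labels(y):
    """Integer labels pass through; one-hot rows become their class index."""
    y = np.asarray(y)
    if y.ndim == 1 or y.shape[1] == 1:
        return y.reshape(-1).astype(int)
    return np.argmax(y, axis=1)

def predicted_labels(y_hat):
    """Single sigmoid column: threshold at 0.5; softmax rows: argmax."""
    y_hat = np.asarray(y_hat)
    if y_hat.ndim == 1 or y_hat.shape[1] == 1:
        return (y_hat.reshape(-1) > 0.5).astype(int)
    return np.argmax(y_hat, axis=1)

def per_class_recall(y, y_hat, n_classes):
    """recall_c = TP_c / (TP_c + FN_c); 0.0 if class c never occurs in y."""
    t, p = true_labels(y), predicted_labels(y_hat)
    recalls = {}
    for c in range(n_classes):
        support = np.sum(t == c)              # TP_c + FN_c
        hits = np.sum((t == c) & (p == c))    # TP_c
        recalls[c] = float(hits / support) if support else 0.0
    return recalls

# Binary model with one sigmoid output
print(per_class_recall([0, 0, 1, 1], [[0.2], [0.7], [0.9], [0.4]], 2))
# {0: 0.5, 1: 0.5}

# 3-class model with one-hot targets and softmax outputs
y_onehot = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]]
y_softmax = [[0.6, 0.3, 0.1], [0.5, 0.4, 0.1], [0.1, 0.2, 0.7], [0.2, 0.5, 0.3]]
print(per_class_recall(y_onehot, y_softmax, 3))
# {0: 1.0, 1: 0.0, 2: 0.5}
```

For the binary case, note that "recall of class 0" on a single sigmoid output is the true negative rate, which a plain Recall() on that output does not report.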
answered Nov 17 '22 10:11 by mujjiga


In TF2, tf.keras.metrics.Recall gained a class_id argument that makes exactly this possible. Example using FashionMNIST:

import tensorflow as tf

(x_train, y_train), _ = tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train[..., None].astype('float32') / 255
y_train = tf.keras.utils.to_categorical(y_train)

input_shape = x_train.shape[1:]
model = tf.keras.Sequential([
  tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=input_shape),
  tf.keras.layers.MaxPool2D(pool_size=2),
  tf.keras.layers.Dropout(0.3),

  tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'),
  tf.keras.layers.MaxPool2D(pool_size=2),
  tf.keras.layers.Dropout(0.3),

  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(units=256, activation='relu'),
  tf.keras.layers.Dropout(0.5),
  tf.keras.layers.Dense(units=10, activation='softmax')])

model.compile(loss='categorical_crossentropy', optimizer='Adam',
  # name= makes the history keys predictable: 'recall_0', ..., 'recall_9'
  metrics=[tf.keras.metrics.Recall(class_id=i, name=f'recall_{i}') for i in range(10)])
model.fit(x_train, y_train, batch_size=128, epochs=50)
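One subtlety worth checking: with softmax outputs, Recall(class_id=i) as used in the TF2 example keeps its default decision threshold of 0.5, so a sample only counts as predicted class i if that class's probability exceeds 0.5, not merely if it is the argmax. The TF 1.13 subclass in this answer uses the argmax rule instead. The NumPy sketch below contrasts the two rules on toy data; recall_threshold and recall_argmax are hypothetical helpers that model my reading of the two behaviours, not TensorFlow code. If I read the documentation correctly, tf.keras.metrics.Recall also accepts a top_k argument, and top_k=1 together with class_id should give the argmax behaviour; verify this against your TF version.

```python
import numpy as np

def recall_threshold(y_true_onehot, y_prob, class_id, threshold=0.5):
    """Recall where a sample counts as predicted `class_id` only if its
    probability for that class exceeds `threshold`."""
    t = np.asarray(y_true_onehot)[:, class_id] == 1
    p = np.asarray(y_prob)[:, class_id] > threshold
    return float(np.sum(t & p) / np.sum(t))

def recall_argmax(y_true_onehot, y_prob, class_id):
    """Recall where the predicted class is the argmax of each row."""
    t = np.asarray(y_true_onehot)[:, class_id] == 1
    p = np.argmax(np.asarray(y_prob), axis=1) == class_id
    return float(np.sum(t & p) / np.sum(t))

y_true = [[1, 0, 0], [1, 0, 0], [0, 1, 0]]
y_prob = [[0.45, 0.35, 0.20],   # argmax is class 0, but 0.45 < 0.5
          [0.80, 0.10, 0.10],
          [0.10, 0.60, 0.30]]

print(recall_threshold(y_true, y_prob, 0))  # 0.5
print(recall_argmax(y_true, y_prob, 0))     # 1.0
```

With many classes, the maximum softmax probability is often below 0.5, so the two rules can diverge noticeably for minority classes.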

In TF 1.13, tf.keras.metrics.Recall does not have this class_id argument, but it can be added by subclassing (something that, somewhat surprisingly, seems impossible in the alpha release of TF2).

import tensorflow as tf

class Recall(tf.keras.metrics.Recall):

  def __init__(self, *, class_id, **kwargs):
    super().__init__(**kwargs)
    self.class_id = class_id

  def update_state(self, y_true, y_pred, sample_weight=None):
    # Keep only the one-hot column of the class of interest
    y_true = y_true[:, self.class_id]
    # A sample counts as predicted positive when class_id is its argmax
    y_pred = tf.cast(tf.equal(
      tf.math.argmax(y_pred, axis=-1), self.class_id), dtype=tf.float32)
    return super().update_state(y_true, y_pred, sample_weight)
answered Nov 17 '22 12:11 by P-Gn