How to implement a custom metric in Keras?

I get this error:

sum() got an unexpected keyword argument 'out'

when I run this code:

    import pandas as pd, numpy as np
    import keras
    from keras.layers.core import Dense, Activation
    from keras.models import Sequential

    def AUC(y_true,y_pred):
        not_y_pred=np.logical_not(y_pred)
        y_int1=y_true*y_pred
        y_int0=np.logical_not(y_true)*not_y_pred
        TP=np.sum(y_pred*y_int1)
        FP=np.sum(y_pred)-TP
        TN=np.sum(not_y_pred*y_int0)
        FN=np.sum(not_y_pred)-TN
        TPR=np.float(TP)/(TP+FN)
        FPR=np.float(FP)/(FP+TN)
        return((1+TPR-FPR)/2)

    # Input datasets

    train_df = pd.DataFrame(np.random.rand(91,1000))
    train_df.iloc[:,-2]=(train_df.iloc[:,-2]>0.8)*1

    model = Sequential()
    model.add(Dense(output_dim=60, input_dim=91, init="glorot_uniform"))
    model.add(Activation("sigmoid"))
    model.add(Dense(output_dim=1, input_dim=60, init="glorot_uniform"))
    model.add(Activation("sigmoid"))

    model.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=[AUC])

    train_df.iloc[:,-1]=np.ones(train_df.shape[0]) #bias
    X=train_df.iloc[:,:-1].values
    Y=train_df.iloc[:,-1].values
    print X.shape,Y.shape

    model.fit(X, Y, batch_size=50,show_accuracy = False, verbose = 1)
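The error message above appears to come from NumPy rather than from Keras. Keras passes y_true and y_pred to a metric as symbolic tensors. When np.sum receives an object that is not an ndarray but has its own sum method, NumPy delegates to that method and passes it an out keyword. A tensor whose sum() does not accept out then raises exactly this TypeError. The sketch below reproduces that mechanism without Keras. FakeTensor is a hypothetical stand-in for a backend tensor, not a real Keras or backend class.

```python
import numpy as np


class FakeTensor:
    """Hypothetical stand-in for a symbolic backend tensor (not a real Keras class).

    It has its own sum() method, but, presumably like the tensor that raised
    the error in the question, that method has no ``out`` parameter.
    """

    def sum(self, axis=None):
        return "symbolic sum"


def reproduce_error():
    """Call np.sum on a FakeTensor and return the TypeError message, if any."""
    try:
        # np.sum sees a non-ndarray object that has a .sum method and
        # delegates to it as obj.sum(axis=None, out=None); FakeTensor.sum
        # does not accept `out`, so Python raises a TypeError.
        np.sum(FakeTensor())
    except TypeError as err:
        return str(err)
    return None


print(reproduce_error())
```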

Is it possible to implement a custom metric without looping over the batches myself and editing the source code?

asked Jun 06 '16 12:06 by Philippe C


1 Answer

Here I'm answering the title question rather than the OP's exact problem, because this question shows up at the top when I google the topic.

You can implement a custom metric in two ways.

  1. As described in the Keras documentation:

    import keras.backend as K

    def mean_pred(y_true, y_pred):
        # y_pred is a backend tensor, so use backend ops rather than NumPy
        return K.mean(y_pred)

    model.compile(optimizer='sgd',
                  loss='binary_crossentropy',
                  metrics=['accuracy', mean_pred])

    Keep in mind that y_true and y_pred are tensors, not NumPy arrays, as Marcin Możejko's answer points out. To calculate the metric correctly, you therefore need to use keras.backend functions instead of NumPy ones. For details, see the SO question "How to calculate F1 Macro in Keras?".

  2. Alternatively, you can implement it in a hackier way, as suggested in a Keras GitHub issue, by passing a callback through the callbacks argument of model.fit:

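    import keras
    import numpy as np
    from keras.optimizers import SGD
    from sklearn.metrics import roc_auc_score

    model = keras.models.Sequential()
    # ...
    sgd = SGD(learning_rate=0.001, momentum=0.9)
    model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])


    class Metrics(keras.callbacks.Callback):
        """Computes ROC AUC on held-out data at the end of every epoch."""

        def __init__(self, validation_data):
            super().__init__()
            # Newer Keras versions no longer set self.validation_data on
            # callbacks, so the validation set is passed in explicitly.
            self.X_val, self.y_val = validation_data

        def on_train_begin(self, logs=None):
            self._data = []

        def on_epoch_end(self, epoch, logs=None):
            # Class probabilities from the model being trained (a NumPy array)
            y_predict = np.asarray(self.model.predict(self.X_val))
            # y_val is one-hot encoded, so roc_auc_score scores each column as a
            # binary problem and returns the macro average; probabilities are kept
            # because argmax-ed labels break multiclass ROC AUC.
            self._data.append({
                'val_rocauc': roc_auc_score(self.y_val, y_predict),
            })

        def get_data(self):
            return self._data


    metrics = Metrics(validation_data=(X_val, y_val))
    history = model.fit(X_train, y_train, epochs=100,
                        validation_data=(X_val, y_val),
                        callbacks=[metrics])
    metrics.get_data()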
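With the callback approach, model.predict returns plain NumPy arrays. That means a NumPy formula like the question's AUC function can be evaluated there directly. The sketch below is a NumPy version of that formula. Note that (1 + TPR - FPR)/2 equals (TPR + TNR)/2, which is the balanced accuracy of thresholded predictions, not the area under the ROC curve. The helper name balanced_accuracy and the 0.5 threshold are assumptions for illustration.

```python
import numpy as np


def balanced_accuracy(y_true, y_prob, threshold=0.5):
    """NumPy version of the question's metric, (1 + TPR - FPR) / 2.

    y_true: 0/1 labels; y_prob: predicted probabilities (e.g. model.predict
    output for a single sigmoid unit). The 0.5 threshold is an assumption:
    logical_not on raw probabilities would treat any non-zero value as positive.
    """
    y_true = np.asarray(y_true).ravel().astype(bool)
    y_pred = np.asarray(y_prob).ravel() >= threshold

    tp = np.sum(y_true & y_pred)
    fp = np.sum(~y_true & y_pred)
    tn = np.sum(~y_true & ~y_pred)
    fn = np.sum(y_true & ~y_pred)

    # Both rates are undefined (nan) if a class is missing from y_true
    tpr = tp / (tp + fn)  # true positive rate (sensitivity)
    fpr = fp / (fp + tn)  # false positive rate (1 - specificity)
    return (1 + tpr - fpr) / 2


y_val = np.array([1, 0, 1, 1, 0, 0])
y_prob = np.array([0.9, 0.2, 0.4, 0.7, 0.6, 0.1])
print(balanced_accuracy(y_val, y_prob))  # TPR = 2/3, FPR = 1/3 -> 0.666...
```

For a single sigmoid output like the question's model, this could be called from on_epoch_end with the validation labels and the model.predict output, and the result stored alongside val_rocauc.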
answered Oct 03 '22 05:10 by vogdb