I'm trying to define the F1-score as a custom metric in TensorFlow for a DNNClassifier. To do that, I wrote a function
def metric_fn(predictions=[], labels=[], weights=[]):
    P, _ = tf.contrib.metrics.streaming_precision(predictions, labels)
    R, _ = tf.contrib.metrics.streaming_recall(predictions, labels)
    if P + R == 0:
        return 0
    return 2*(P*R)/(P+R)
that uses streaming_precision and streaming_recall from TensorFlow to calculate the F1 score. After that, I added a new entry to validation_metrics:
validation_metrics = {
    "accuracy":
        tf.contrib.learn.MetricSpec(
            metric_fn=tf.contrib.metrics.streaming_accuracy,
            prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
    "precision":
        tf.contrib.learn.MetricSpec(
            metric_fn=tf.contrib.metrics.streaming_precision,
            prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
    "recall":
        tf.contrib.learn.MetricSpec(
            metric_fn=tf.contrib.metrics.streaming_recall,
            prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
    "f1score":
        tf.contrib.learn.MetricSpec(
            metric_fn=metric_fn,
            prediction_key=tf.contrib.learn.PredictionKey.CLASSES)
}
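In case it is relevant, the dict is wired into training through a ValidationMonitor, roughly as in the tf.contrib.learn tutorials (sketch only; test_x, test_y, train_x, and train_y stand in for my actual data):

validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
    test_x,
    test_y,
    every_n_steps=50,
    metrics=validation_metrics)
classifier.fit(x=train_x, y=train_y, steps=2000, monitors=[validation_monitor])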
However, although I get correct precision and recall values, f1score is always nan:
INFO:tensorflow:Saving dict for global step 151: accuracy = 0.982456, accuracy/baseline_label_mean = 0.397661, accuracy/threshold_0.500000_mean = 0.982456, auc = 0.982867, f1score = nan, global_step = 151, labels/actual_label_mean = 0.397661, labels/prediction_mean = 0.406118, loss = 0.310612, precision = 0.971014, precision/positive_threshold_0.500000_mean = 0.971014, recall = 0.985294, recall/positive_threshold_0.500000_mean = 0.985294
Something is wrong with my metric_fn, but I can't figure it out.
The values P and R obtained inside metric_fn are of the form Tensor("precision/value:0", shape=(), dtype=float32), which I find a bit strange: I was expecting concrete scalar values, not symbolic tensors.
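For example, just constructing the metric and printing what it returns shows the pair (a quick check with hypothetical placeholder inputs):

import tensorflow as tf

predictions = tf.placeholder(tf.int64, shape=[None])  # stand-ins for the real tensors
labels = tf.placeholder(tf.int64, shape=[None])

P, P_update_op = tf.contrib.metrics.streaming_precision(predictions, labels)
print(P)            # Tensor("precision/value:0", shape=(), dtype=float32)
print(P_update_op)  # the matching update op, which my metric_fn discards as "_"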
Any help is appreciated.
I think the problem comes from the fact that the streaming metrics you use within your metric_fn never get updated. Each streaming metric returns a value tensor together with an update op; if the update op is never run, the internal counts stay at zero, so P and R both evaluate to 0 and 2*(P*R)/(P+R) becomes 0/0 = nan. (Note also that if P + R == 0 compares a Tensor object against a Python int at graph-construction time, so that branch is never taken.)
Try the following (I also included minor modifications to my taste):
def metric_fn(predictions=None, labels=None, weights=None):
    P, update_op1 = tf.contrib.metrics.streaming_precision(predictions, labels)
    R, update_op2 = tf.contrib.metrics.streaming_recall(predictions, labels)
    eps = 1e-5
    # Return the F1 value *and* a grouped update op, so that the evaluation
    # loop actually refreshes the underlying precision/recall counts.
    return (2*(P*R)/(P+R+eps), tf.group(update_op1, update_op2))
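You can also sanity-check this fix outside the Estimator loop. Here is a minimal, self-contained sketch, assuming TensorFlow 1.x with tf.contrib available; the toy predictions and labels are made up for illustration:

import tensorflow as tf

def metric_fn(predictions=None, labels=None, weights=None):
    P, update_op1 = tf.contrib.metrics.streaming_precision(predictions, labels)
    R, update_op2 = tf.contrib.metrics.streaming_recall(predictions, labels)
    eps = 1e-5
    return (2*(P*R)/(P+R+eps), tf.group(update_op1, update_op2))

# Toy batch: 2 true positives, 1 false positive, 0 false negatives.
predictions = tf.constant([1, 0, 1, 1], dtype=tf.int64)
labels = tf.constant([1, 0, 0, 1], dtype=tf.int64)
f1, f1_update_op = metric_fn(predictions=predictions, labels=labels)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # streaming metrics keep counts in local variables
    sess.run(f1_update_op)  # accumulate the counts for this batch
    print(sess.run(f1))     # P = 2/3, R = 1.0 -> F1 ~ 0.8

The eps term keeps the division well-defined when no positives have been seen yet (P and R both 0), and tf.group ensures both update ops run on every evaluation step.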