Tensorflow plot tf.metrics.precision_at_thresholds in Tensorboard through eval_metric_ops

tf.metrics.precision_at_thresholds() takes three arguments: labels, predictions, and thresholds, where thresholds is a Python list or tuple of thresholds in [0, 1]. The function returns "a float Tensor of shape [len(thresholds)]", which is problematic for automatically plotting eval_metric_ops to TensorBoard (as I believe they are expected to be scalars). The values print to the console just fine, but I would also like to plot them in TensorBoard. Is there any adjustment that can be made to plot these values in TensorBoard?
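For concreteness, a minimal sketch of the setup (the inputs here are placeholders for illustration): the metric returns a pair of [len(thresholds)]-shaped tensors, which eval_metric_ops cannot plot directly:

import tensorflow as tf

# Placeholder inputs for illustration.
labels = tf.constant([1.0, 0.0, 1.0, 1.0])
predictions = tf.constant([0.9, 0.4, 0.6, 0.2])
thresholds = [0.3, 0.5, 0.7]

# value and update_op both have shape [len(thresholds)].
value, update_op = tf.metrics.precision_at_thresholds(
    labels=labels, predictions=predictions, thresholds=thresholds)

print(value.shape)  # (3,) -- one precision per threshold, not a scalar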

asked Sep 15 '17 by reese0106

1 Answer

I found it really strange that TensorFlow (as of 1.8) does not offer a summary function for metrics like tf.metrics.precision_at_thresholds (or tf.metrics.*_at_thresholds in general). The following is a minimal working example:

import tensorflow as tf

def summarize_metrics(metric_ops):
    for metric_op in metric_ops:
        shape = metric_op.shape.as_list()
        if shape:  # this is a metric created with any of tf.metrics.*_at_thresholds
            summary_components = tf.split(metric_op, shape[0])
            for i, summary_component in enumerate(summary_components):
                tf.summary.scalar(
                    # use the op name (no ':0' suffix) as the summary tag
                    name='{op_name}_{i}'.format(op_name=metric_op.op.name, i=i),
                    tensor=tf.squeeze(summary_component, axis=[0])
                )
        else:  # this already is a scalar metric operator
            tf.summary.scalar(name=metric_op.op.name, tensor=metric_op)

# labels, predictions, and thresholds are assumed to be defined elsewhere
precision, precision_op = tf.metrics.precision_at_thresholds(labels=labels,
                                                             predictions=predictions,
                                                             thresholds=thresholds)
summarize_metrics([precision_op])
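To actually get these curves into TensorBoard you still have to run the metric update, evaluate the merged summaries, and write them out; a minimal sketch (the logdir is a placeholder):

merged = tf.summary.merge_all()
writer = tf.summary.FileWriter('/tmp/metrics_demo')  # placeholder logdir

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])  # tf.metrics use local vars
    sess.run(precision_op)                        # update the metric counters
    writer.add_summary(sess.run(merged), global_step=0)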

The downside of this approach, in general, is that any notion of which thresholds you used to create the metric in the first place is lost when summarizing. I came up with a slightly more complex, but easier-to-use, solution that uses collections to store all metric update operators.

# Create a metric and let it add its vars and update op to the specified collections
thresholds = [0.5, 0.7]
tf.metrics.recall_at_thresholds(
    labels=labels, predictions=predictions, thresholds=thresholds,
    metrics_collections=['metrics_vars'], updates_collections=['metrics_update_ops']
)

# Anywhere else, call the summary method I provide in the Gist at the bottom [1].
# Because we provide a mapping from a scope pattern to the thresholds, we can
# assign them later
summarize_metrics(list_lookup={'recall_at_thresholds': thresholds})

The implementation in the Gist [1] below also supports options for formatting the sometimes cryptic names of the metrics nicely.

[1]: https://gist.github.com/patzm/961dcdcafbf3c253a056807c56604628
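For reference, a rough approximation of what such a collection-based summarizer could look like (this is my own sketch, not the code from the Gist; the collection name matches the one used above):

import tensorflow as tf

def summarize_metrics(list_lookup=None, collection='metrics_update_ops'):
    list_lookup = list_lookup or {}
    for metric_op in tf.get_collection(collection):
        op_name = metric_op.op.name
        # Find the thresholds whose scope pattern matches this op's name, if any.
        thresholds = next(
            (t for pattern, t in list_lookup.items() if pattern in op_name), None)
        shape = metric_op.shape.as_list()
        if shape and thresholds:
            components = tf.split(metric_op, shape[0])
            for threshold, component in zip(thresholds, components):
                tf.summary.scalar(
                    name='{}_at_{}'.format(op_name, threshold),
                    tensor=tf.squeeze(component, axis=[0]))
        else:
            tf.summary.scalar(name=op_name, tensor=metric_op)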

Here is how this could look in TensorBoard: [Imgur screenshot]

answered Nov 15 '22 by patzm