Custom metrics with tf.estimator

I want TensorFlow to calculate the coefficient of determination (R squared) during evaluation of my estimator. I tried to implement it in the following way, loosely based on the implementation of the built-in metrics in tf.metrics:

import tensorflow as tf
from tensorflow.python.framework import ops

def r_squared(labels, predictions, weights=None,
              metrics_collections=None,
              updates_collections=None,
              name=None):
    # R^2 of the current batch: 1 - SS_res / SS_tot
    total_error = tf.reduce_sum(tf.square(labels - tf.reduce_mean(labels)))
    unexplained_error = tf.reduce_sum(tf.square(labels - predictions))
    r_sq = 1 - tf.div(unexplained_error, total_error)

    # update_rsq_op = ?

    if metrics_collections:
        ops.add_to_collections(metrics_collections, r_sq)

    # if updates_collections:
    #     ops.add_to_collections(updates_collections, update_rsq_op)

    return r_sq #, update_rsq_op

Then, I use this function as a metric in the EstimatorSpec:

estim_specs = tf.estimator.EstimatorSpec(
    ...
    eval_metric_ops={
        'r_squared': r_squared(labels, predictions),
        ...
    })

However, this fails because my implementation of R squared does not return an update_op:

TypeError: Values of eval_metric_ops must be (metric_value, update_op) tuples, given: Tensor("sub_4:0", dtype=float64) for key: r_squared

Now I wonder: what exactly is the update_op supposed to do? Do I actually need to implement an update_op, or can I somehow create some kind of dummy update_op? And if it is necessary, how would I implement it?
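To illustrate the contract behind the (metric_value, update_op) pair: a TF1 streaming metric keeps its state in accumulator variables, update_op folds one evaluation batch into that state, and metric_value reads the accumulated result once all batches have been processed. The plain-Python sketch below models that behaviour for a running mean, roughly what tf.metrics.mean keeps (a total and a count). It is not TensorFlow code; `StreamingMean` and `evaluate` are hypothetical names, and the evaluation loop is a simplified assumption about how the estimator drives the two ops.

```python
class StreamingMean:
    """Hypothetical plain-Python model of the state behind tf.metrics.mean."""

    def __init__(self):
        # Accumulators (in TF1 these live in local variables).
        self.total = 0.0
        self.count = 0

    def update(self, values):
        # Role of update_op: fold one batch into the running state and
        # return the updated mean.
        self.total += sum(values)
        self.count += len(values)
        return self.result()

    def result(self):
        # Role of metric_value: read the accumulated state (0.0 if empty).
        return self.total / self.count if self.count else 0.0


def evaluate(metric, batches):
    # Simplified evaluation loop: run update_op once per batch,
    # then read metric_value at the end.
    for batch in batches:
        metric.update(batch)
    return metric.result()


batches = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
print(evaluate(StreamingMean(), batches))  # 3.5
```

The point of the split is that the value reflects every batch seen so far, not just the last one, which is why a metric without an update step cannot be aggregated over an evaluation run.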

asked Dec 11 '17 at 13:12 by msteininger


1 Answer

OK, I was able to figure it out: I can wrap my metric in tf.metrics.mean and use the update_op that it returns. This seems to work for me.

import tensorflow as tf
from tensorflow.python.framework import ops

def r_squared(labels, predictions, weights=None,
              metrics_collections=None,
              updates_collections=None,
              name=None):
    # R^2 of the current batch: 1 - SS_res / SS_tot
    total_error = tf.reduce_sum(tf.square(labels - tf.reduce_mean(labels)))
    unexplained_error = tf.reduce_sum(tf.square(labels - predictions))
    r_sq = 1 - tf.div(unexplained_error, total_error)

    # tf.metrics.mean keeps a running total and count: update_rsq_op adds
    # this batch's r_sq, m_r_sq reads the mean over all batches so far.
    m_r_sq, update_rsq_op = tf.metrics.mean(r_sq)

    if metrics_collections:
        ops.add_to_collections(metrics_collections, m_r_sq)

    if updates_collections:
        ops.add_to_collections(updates_collections, update_rsq_op)

    return m_r_sq, update_rsq_op
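Because r_sq is a scalar per batch, wrapping it in tf.metrics.mean reports the average of the per-batch R squared values, with every batch weighted equally. In general that differs from R squared computed over the whole evaluation set. The plain-Python sketch below (no TensorFlow; `batch_r_squared` and `StreamingRSquared` are hypothetical names) contrasts the two on a toy split of four points and shows a streaming alternative: accumulate n, sum(y), sum(y^2) and the residual sum of squares in the update step, and compute R squared only when the value is read. In TF1 the same idea could be built from existing metrics, since R^2 = 1 - MSE / (mean(y^2) - mean(y)^2): tf.metrics.mean on the labels and on the squared labels plus tf.metrics.mean_squared_error, with the three update ops combined via tf.group (not shown here).

```python
def batch_r_squared(labels, predictions):
    """R^2 of a single batch: 1 - SS_res / SS_tot (same formula as above)."""
    mean_y = sum(labels) / len(labels)
    ss_tot = sum((y - mean_y) ** 2 for y in labels)
    ss_res = sum((y - p) ** 2 for y, p in zip(labels, predictions))
    return 1.0 - ss_res / ss_tot


class StreamingRSquared:
    """Hypothetical streaming R^2 over all batches seen so far."""

    def __init__(self):
        self.n = 0
        self.sum_y = 0.0
        self.sum_y2 = 0.0
        self.ss_res = 0.0

    def update(self, labels, predictions):
        # Counterpart of update_op: add this batch's sufficient statistics.
        for y, p in zip(labels, predictions):
            self.n += 1
            self.sum_y += y
            self.sum_y2 += y * y
            self.ss_res += (y - p) ** 2

    def result(self):
        # SS_tot over all labels = sum(y^2) - sum(y)^2 / n.
        # Assumes at least two distinct label values were seen.
        ss_tot = self.sum_y2 - self.sum_y ** 2 / self.n
        return 1.0 - self.ss_res / ss_tot


batches = [([1.0, 2.0], [1.0, 3.0]), ([3.0, 4.0], [3.0, 4.0])]

# What wrapping the per-batch value in a mean reports:
per_batch = [batch_r_squared(y, p) for y, p in batches]
print(sum(per_batch) / len(per_batch))  # 0.0

# R^2 over the whole evaluation set:
metric = StreamingRSquared()
for y, p in batches:
    metric.update(y, p)
print(metric.result())  # 0.8
```

Here the batch values are -1.0 and 1.0, so their mean is 0.0, while all four points together give 1 - 1/5 = 0.8. The sum(y^2) - sum(y)^2 / n form is convenient for streaming but can lose precision when labels are large relative to their spread.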

answered Oct 14 '22 at 16:10 by msteininger