
Custom metrics with tf.estimator


I want TensorFlow to calculate the coefficient of determination (R squared) during evaluation of my estimator. I tried to implement it in the following way, loosely based on the implementation of the official metrics:

import tensorflow as tf
from tensorflow.python.framework import ops

def r_squared(labels, predictions, weights=None,
              metrics_collections=None,
              updates_collections=None,
              name=None):
    # R^2 = 1 - SS_res / SS_tot, computed on the tensors passed in (one batch)
    total_error = tf.reduce_sum(tf.square(labels - tf.reduce_mean(labels)))
    unexplained_error = tf.reduce_sum(tf.square(labels - predictions))
    r_sq = 1 - tf.div(unexplained_error, total_error)

    # update_rsq_op = ?

    if metrics_collections:
        ops.add_to_collections(metrics_collections, r_sq)

    # if updates_collections:
    #     ops.add_to_collections(updates_collections, update_rsq_op)

    return r_sq #, update_rsq_op
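For reference, the function above computes the usual definition, with unexplained_error as the residual sum of squares and total_error as the total sum of squares around the label mean:

$$
R^2 = 1 - \frac{\sum_i (y_i - \hat{y}_i)^2}{\sum_i (y_i - \bar{y})^2}
$$

where y_i are the labels, ŷ_i the predictions, and ȳ the mean of the labels in the tensors passed in, which during evaluation is a single batch.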

Then, I use this function as a metric in the EstimatorSpec:

estim_specs = tf.estimator.EstimatorSpec(
    ...
    eval_metric_ops={
        'r_squared': r_squared(labels, predictions),
        ...
    })

However, this fails since my implementation of R squared doesn't return an update_op.

TypeError: Values of eval_metric_ops must be (metric_value, update_op) tuples, given: Tensor("sub_4:0", dtype=float64) for key: r_squared

Now I wonder, what exactly is the update_op supposed to do? Do I actually need to implement an update_op or can I somehow create some kind of dummy update_op? And if it is necessary, how would I implement it?
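Some background on the contract: in TF1, each tf.metrics.* function returns a pair. The first element is a value tensor that reads accumulated state (kept in local variables), and the second is an update_op that folds one batch into that state. During evaluation, Estimator runs every update_op once per batch and reads the value tensors after the last batch. The plain-Python sketch below only mimics that contract to illustrate it; StreamingMean and its methods are hypothetical names, not TensorFlow API.

```python
# Conceptual, framework-free model of TF1's (metric_value, update_op) pair.
# NOT TensorFlow code: StreamingMean is a hypothetical class that mimics
# what tf.metrics.mean does with its local variables.

class StreamingMean:
    def __init__(self):
        # State that persists across evaluation batches
        # (TF1 keeps the equivalent in local variables).
        self.total = 0.0
        self.count = 0

    def update_op(self, values):
        # Run once per batch: fold the batch into the running state
        # and return the updated mean.
        self.total += sum(values)
        self.count += len(values)
        return self.value()

    def value(self):
        # Read the aggregate; does not modify state.
        return self.total / self.count if self.count else 0.0


metric = StreamingMean()
for batch in ([1.0, 2.0], [3.0], [4.0, 5.0, 6.0]):
    metric.update_op(batch)  # analogous to Estimator running update_op per batch
print(metric.value())  # 3.5
```

A "dummy" update_op would leave this state untouched, so the value read at the end would not reflect the batches seen during evaluation.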


Solution

  • OK, I was able to figure it out: I can wrap my metric in a mean metric (tf.metrics.mean) and use its update_op. This seems to work for me.

    import tensorflow as tf
    from tensorflow.python.framework import ops

    def r_squared(labels, predictions, weights=None,
                  metrics_collections=None,
                  updates_collections=None,
                  name=None):
        # Per-batch R^2 (weights and name are accepted but unused here).
        total_error = tf.reduce_sum(tf.square(labels - tf.reduce_mean(labels)))
        unexplained_error = tf.reduce_sum(tf.square(labels - predictions))
        r_sq = 1 - tf.div(unexplained_error, total_error)

        # tf.metrics.mean keeps running total/count local variables:
        # update_rsq_op folds in this batch's r_sq, m_r_sq reads the average.
        m_r_sq, update_rsq_op = tf.metrics.mean(r_sq)

        if metrics_collections:
            ops.add_to_collections(metrics_collections, m_r_sq)

        if updates_collections:
            ops.add_to_collections(updates_collections, update_rsq_op)

        return m_r_sq, update_rsq_op
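One caveat with this approach: since r_sq is a scalar per batch, tf.metrics.mean reports the unweighted average of the per-batch R² values, not R² computed over the whole evaluation set. The two generally differ, most visibly when batches have different label means. The pure-Python sketch below (hypothetical helper names, no TensorFlow) shows the difference and one alternative: accumulate n, Σy, Σy², and Σ(y − ŷ)² across batches and compute R² once at the end.

```python
# Hypothetical helper names; pure Python so it runs without TensorFlow.

def r_squared(labels, predictions):
    # R^2 = 1 - SS_res / SS_tot over one set of samples.
    mean = sum(labels) / len(labels)
    ss_tot = sum((y - mean) ** 2 for y in labels)
    ss_res = sum((y - p) ** 2 for y, p in zip(labels, predictions))
    return 1.0 - ss_res / ss_tot


class StreamingRSquared:
    """Accumulates sufficient statistics so R^2 is exact over all batches."""

    def __init__(self):
        self.n = 0
        self.sum_y = 0.0
        self.sum_y2 = 0.0
        self.ss_res = 0.0

    def update(self, labels, predictions):
        # Per-batch "update_op": add this batch's contributions.
        self.n += len(labels)
        self.sum_y += sum(labels)
        self.sum_y2 += sum(y * y for y in labels)
        self.ss_res += sum((y - p) ** 2 for y, p in zip(labels, predictions))

    def result(self):
        # SS_tot = sum(y^2) - (sum(y))^2 / n
        ss_tot = self.sum_y2 - self.sum_y ** 2 / self.n
        return 1.0 - self.ss_res / ss_tot


batches = [
    ([1.0, 2.0, 3.0], [1.5, 2.0, 2.5]),
    ([10.0, 11.0, 12.0], [9.0, 11.0, 13.0]),
]

# What averaging per-batch R^2 (as tf.metrics.mean does above) reports:
mean_of_batches = sum(r_squared(y, p) for y, p in batches) / len(batches)

streaming = StreamingRSquared()
for y, p in batches:
    streaming.update(y, p)

all_y = [v for y, _ in batches for v in y]
all_p = [v for _, p in batches for v in p]

print(mean_of_batches)          # 0.375
print(streaming.result())       # matches R^2 over all samples
print(r_squared(all_y, all_p))
```

In a TF1 metric function, the same four running sums could live in local variables incremented by the update_op, with the value tensor computing the final ratio; that port is not shown here. Note that the Σy² − (Σy)²/n form of SS_tot can lose precision when labels are large relative to their spread.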