python, tensorflow, neural-network, deep-learning, tensorflow-slim

How to calculate top-k in-class accuracies using TensorFlow Slim metrics


I'd like to extend this script so that it can evaluate the top-k accuracies per class. I hope it boils down to adding a metric to the following code snippet:

# Define the metrics:
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
    'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
    'Recall_5': slim.metrics.streaming_recall_at_k(
        logits, labels, 5), })

I already followed this comment to add the confusion matrix, which allows me to calculate the top-1 in-class accuracies. However, I'm not sure how to get the top-k values, as I can't find an appropriate slim metric.
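
For reference, per-class top-1 accuracy follows from such a confusion matrix roughly like this (a simplified NumPy sketch with made-up numbers, not the actual evaluation code):

import numpy as np

# made-up accumulated confusion matrix: rows = true class, columns = predicted class
confusion = np.array([[50,  2,  3],
                      [ 4, 40,  1],
                      [ 0,  5, 45]])

# top-1 in-class accuracy: correct predictions (diagonal) divided by samples per class (row sums)
top1_per_class = np.diag(confusion) / np.maximum(confusion.sum(axis=1), 1)
print(top1_per_class)  # approx. [0.909 0.889 0.9]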



Solution

  • I finally found a solution based on the linked confusion matrix example.

    It's more of a tweak than a beautiful solution, but it works: I'm reusing the confusion matrix along with the top_k predictions. The values are stored in the first two columns of the tweaked confusion matrix (see the small toy example after the two helper functions below).
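
    The snippets below assume imports roughly along these lines (exact module paths may differ between TensorFlow 1.x versions; `variables` comes from `tensorflow.python.ops`, as in the linked issue comment):

    import numpy as np
    import tensorflow as tf
    import tensorflow.contrib.slim as slim
    from tensorflow.python.ops import variables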

    This function creates the streaming metric:

    def _get_top_k_per_class_correct_predictions_streaming_metrics(softmax_output, labels, num_classes, top_k):
        """Aggregates the correct predictions per class according to the in-top-k criterion.

        :param softmax_output: The per-class probabilities as predicted by the net.
        :param labels: The ground truth data. No(!) one-hot encoding here.
        :param num_classes: Total number of available classes.
        :param top_k: The k used for the in-top-k criterion, e.g. 5 for top-5 accuracy.
        :return: A (value, update_op) tuple holding the accumulated matrix and its update operation.
        """
        with tf.name_scope("eval"):
            # create a tensor with <batch_size> elements. each element is either True (prediction correct) or False
            batch_correct_prediction_top_k = tf.nn.in_top_k(softmax_output, labels, top_k,
                                                            name="batch_correct_prediction_top_{}".format(top_k))

            # the above output is boolean, but we need integers to sum them up
            batch_correct_prediction_top_k = tf.cast(batch_correct_prediction_top_k, tf.int32)

            # use the confusion matrix implementation to get the desired results
            # we actually need only the first two columns of the returned matrix.
            batch_correct_prediction_top_k_matrix = tf.confusion_matrix(labels, batch_correct_prediction_top_k,
                                                                        num_classes=num_classes,
                                                                        name='batch_correct_prediction_top{}_matrix'.format(
                                                                            top_k))

            correct_prediction_top_k_matrix = _create_local_var('correct_prediction_top{}_matrix'.format(top_k),
                                                                shape=[num_classes, num_classes],
                                                                dtype=tf.int32)
            # Create the update op for doing a "+=" accumulation on the batch
            correct_prediction_top_k_matrix_update = correct_prediction_top_k_matrix.assign(
                correct_prediction_top_k_matrix + batch_correct_prediction_top_k_matrix)

        return correct_prediction_top_k_matrix, correct_prediction_top_k_matrix_update
    

    as well as:

    def _create_local_var(name, shape, collections=None, validate_shape=True,
                          dtype=tf.float32):
        """Creates a new local variable.

        This method is required to get the confusion matrix.
        see https://github.com/tensorflow/models/issues/1286#issuecomment-317205632

        Args:
          name: The name of the new or existing variable.
          shape: Shape of the new or existing variable.
          collections: A list of collection names to which the Variable will be added.
          validate_shape: Whether to validate the shape of the variable.
          dtype: Data type of the variables.
        Returns:
          The created variable.
        """
        # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES
        collections = list(collections or [])
        collections += [tf.GraphKeys.LOCAL_VARIABLES]
        return variables.Variable(
            initial_value=tf.zeros(shape, dtype=dtype),
            name=name,
            trainable=False,
            collections=collections,
            validate_shape=validate_shape)
    

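    To illustrate the trick with made-up numbers: because the 0/1 in_top_k indicator is fed in as the "prediction", only the first two columns of the confusion matrix ever get filled.

    # toy example with 3 classes and 5 samples (made-up numbers):
    labels = [0, 0, 1, 2, 2]      # ground-truth class ids
    in_top_k = [1, 0, 1, 1, 1]    # 1 = true class is among the top_k predictions
    # tf.confusion_matrix(labels, in_top_k, num_classes=3) evaluates to
    # [[1, 1, 0],   # class 0: 1 miss (column 0), 1 hit (column 1)
    #  [0, 1, 0],   # class 1: 0 misses, 1 hit
    #  [0, 2, 0]]   # class 2: 0 misses, 2 hits
    # columns >= 2 always stay zero, so the first two columns hold all the information.
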
    Add the new metric to the slim config and evaluate:

    # Define the metrics:
    softmax_output = tf.nn.softmax(logits, name="softmax_for_evaluation")
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        [..]
        KEY_ACCURACY5_PER_CLASS_KEY_MATRIX: _get_top_k_per_class_correct_predictions_streaming_metrics(
            softmax_output, labels, self._dataset.num_classes - labels_offset, 5),
        [..]
    })

    # evaluate
    results = slim.evaluation.evaluate_once([..])
    

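    The `[..]` parts depend on the rest of the script. One plausible way to wire it up (an assumption, not taken from the original code; `checkpoint_path`, `eval_dir` and `num_batches` are hypothetical variables) is to run the update ops as `eval_op` and return the value tensors via `final_op`, so that `results` becomes a dict keyed by metric name:

    results = slim.evaluation.evaluate_once(
        master='',
        checkpoint_path=checkpoint_path,           # hypothetical: path to the trained checkpoint
        logdir=eval_dir,                           # hypothetical: directory for evaluation summaries
        num_evals=num_batches,                     # hypothetical: number of batches to evaluate
        eval_op=list(names_to_updates.values()),   # run every metric update op for each batch
        final_op=names_to_values)                  # evaluated once at the end and returned as `results`
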
    Finally, you can use the additional matrix to calculate the top_k accuracies per class:

    def _calc_in_class_accuracy_top_k(self, results):
        """Calculate the top_k accuracies per class.

        :param results: The evaluated metric values, keyed by metric name.
        :return: A numpy array holding the top_k accuracy of each class.
        """
        # use a tweaked confusion matrix to calculate the in-class accuracy5
        # rows represent the real labels
        # the 1-th column contains the number of times that the associated class was correctly classified as one of
        # the top_k results. The 0-th column contains the number of failed predictions. The sum is the total number
        # of provided samples per class.
        matrix_top_k = results[KEY_ACCURACY5_PER_CLASS_KEY_MATRIX]

        n_classes = matrix_top_k.shape[0]
        in_class_accuracy_top_k_per_class = np.zeros(n_classes, np.float64)
        for class_id in range(n_classes):
            correct_top_k = matrix_top_k[class_id][1]
            total_occurrences = np.sum(matrix_top_k[class_id])  # this many samples of the current class exist in total

            # top_k accuracy: correct predictions divided by the number of samples of this class
            in_class_accuracy_top_k_per_class[class_id] = correct_top_k
            if total_occurrences > 0:
                in_class_accuracy_top_k_per_class[class_id] /= total_occurrences

        return in_class_accuracy_top_k_per_class
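
    As a quick sanity check with the toy matrix from above (made-up numbers), the per-class top_k accuracies can be reproduced directly in NumPy:

    matrix_top_k = np.array([[1, 1, 0],    # class 0: 1 miss, 1 hit
                             [0, 1, 0],    # class 1: 0 misses, 1 hit
                             [0, 2, 0]])   # class 2: 0 misses, 2 hits

    # hits (column 1) divided by the number of samples per class (row sums)
    per_class_top_k = matrix_top_k[:, 1] / np.maximum(matrix_top_k.sum(axis=1), 1)
    print(per_class_top_k)  # -> [0.5 1.  1. ]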