allennlp

Writing custom metrics in AllenNLP


I'm writing my first AllenNLP project, which detects specific spans in newspaper articles. I was able to train it on my dataset, and the cross-entropy loss seems to decrease correctly, but I'm having some issues with my metric. I wrote a custom metric that is supposed to estimate how accurately my model predicts spans compared to some ground-truth spans. The problem is that right now, my metric doesn't seem to update correctly even though the loss is decreasing.

I'm not sure how to tackle the problem, so I guess my questions are the following:

  1. What exactly is the reset() function in the Metric class used for?
  2. Apart from writing the __call__(), get_metric() and reset() functions, are there other things to watch out for?

Below is a snapshot of my custom Metric class in case you need it.

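import itertools
from typing import Dict, Optional

import torch
from allennlp.training.metrics import Metric


class SpanIdenficationMetric(Metric):
    def __init__(self) -> None:
        self._s_cardinality = 0  # |S|: number of (merged) spans predicted by the model
        self._t_cardinality = 0  # |T|: number of gold spans in the articles
        self._s_sum = 0  # sum of C(s, t, |s|) over all (s, t) pairs; precision numerator
        self._t_sum = 0  # sum of C(s, t, |t|) over all (s, t) pairs; recall numerator

    def reset(self) -> None:
        # Put the metric back into the state of a freshly constructed one.
        self._s_cardinality = 0
        self._t_cardinality = 0
        self._s_sum = 0
        self._t_sum = 0

    def __call__(self, prop_spans: torch.Tensor, gold_spans: torch.Tensor, mask: Optional[torch.BoolTensor] = None):
        # Called once per batch; adds this batch's counts to the running totals.
        # `mask` is accepted to match the Metric signature but is not used.
        for i, article_spans in enumerate(prop_spans):
            # An article with no predicted spans is skipped entirely,
            # so its gold spans are not added to |T| either.
            if article_spans.numel() == 0:
                continue
            article_gold_spans = gold_spans[i]
            merged_prop_spans = self._merge_intervals(article_spans)
            self._s_cardinality += merged_prop_spans.size(dim=0)
            self._t_cardinality += article_gold_spans.size(dim=0)
            for sspan, tspan in itertools.product(merged_prop_spans, article_gold_spans):
                # Span lengths are inclusive: end - start + 1.
                self._s_sum += self._c_function(sspan, tspan, sspan[1].item() - sspan[0].item() + 1)
                self._t_sum += self._c_function(sspan, tspan, tspan[1].item() - tspan[0].item() + 1)

    def get_metric(self, reset: bool = False) -> Dict[str, float]:
        precision = 0.0
        recall = 0.0
        if self._s_cardinality != 0:
            precision = self._s_sum / self._s_cardinality
        if self._t_cardinality != 0:
            recall = self._t_sum / self._t_cardinality
        if reset:
            self.reset()
        # F1 of the dataset-level precision and recall.
        return {"si-metric": (2 * precision * recall) / (precision + recall) if precision + recall > 0 else 0.0}

    # Helper methods used above (bodies omitted from this snippet).
    def _c_function(self, s, t, h): ...
    def _intersect(self, s, t): ...
    def _merge_intervals(self, prop_spans): ...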

Thank you in advance. Cheers.


Solution

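  • During training, the metric gets called (via Metric.__call__(), normally from your model's forward()) with the results of every batch, and it is supposed to update its internal state when this happens. The trainer expects to get the current value(s) of the metric when it calls Metric.get_metric() (through your model's get_metrics()). Metric.reset() must put the metric back into the state it was in before it was ever called. When get_metric() is called with reset=True, it is expected to reset the metric as well. As far as I know, the trainer reads the metrics with reset=False after every batch (for the progress display) and with reset=True at the end of each training epoch and each validation pass, so every epoch starts from a clean state.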

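    From what I can tell, your code does all these things correctly. It will not give correct results in a distributed setting, though: each worker process only accumulates counts for the batches it sees, and nothing sums those counts across workers. If you are not training on multiple GPUs, that's not a problem.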

    What you're doing is similar to the SQuAD metric: https://github.com/allenai/allennlp-models/blob/main/allennlp_models/rc/metrics/squad_em_and_f1.py The SQuAD metric goes out of its way to call the original SQuAD evaluation code, so it's a little more complicated than what you need, but maybe you can adapt it. The main difference is that you are calculating F scores across the whole dataset, while SQuAD calculates them per document and then averages across documents.

    Finally, you can write a simple test for your metric, similar to the SQuAD test: https://github.com/allenai/allennlp-models/blob/main/tests/rc/metrics/squad_em_and_f1_test.py That might help narrow down where the problem is.
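To make the call / get_metric / reset contract described above concrete, here is a minimal pure-Python sketch that does not import AllenNLP. ExactSpanF1, run_epoch and test_accumulates_and_resets are hypothetical names. The metric only gives credit for exact span matches (not the partial-overlap scoring in the question), and run_epoch only imitates, under my assumption about the trainer, reading metrics with reset=False per batch and reset=True at the end of an epoch. The test function shows the kind of hand-computed checks a small unit test for your own metric could make.

```python
from typing import Dict, List, Sequence, Tuple

Span = Tuple[int, int]  # inclusive (start, end) token offsets
Batch = Tuple[List[List[Span]], List[List[Span]]]  # (predicted, gold), one span list per article


class ExactSpanF1:
    """Hypothetical stand-in metric following the same protocol as an AllenNLP Metric."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Back to the state of a freshly constructed metric.
        self._matched = 0
        self._num_predicted = 0
        self._num_gold = 0

    def __call__(self, predicted: List[List[Span]], gold: List[List[Span]]) -> None:
        # Called once per batch: only adds to the running counts, never clears them.
        for pred_spans, gold_spans in zip(predicted, gold):
            self._num_predicted += len(pred_spans)
            self._num_gold += len(gold_spans)
            self._matched += len(set(pred_spans) & set(gold_spans))

    def get_metric(self, reset: bool = False) -> Dict[str, float]:
        # Compute the value from everything accumulated so far, then optionally clear.
        p = self._matched / self._num_predicted if self._num_predicted else 0.0
        r = self._matched / self._num_gold if self._num_gold else 0.0
        f1 = 2 * p * r / (p + r) if p + r > 0 else 0.0
        if reset:
            self.reset()
        return {"f1": f1}


def run_epoch(metric: ExactSpanF1, batches: Sequence[Batch]) -> Tuple[List[float], float]:
    running = []
    for predicted, gold in batches:
        metric(predicted, gold)  # in AllenNLP this call would happen inside the model's forward()
        running.append(metric.get_metric(reset=False)["f1"])  # per-batch progress read
    final = metric.get_metric(reset=True)["f1"]  # end of epoch: read, then clear
    return running, final


batches: List[Batch] = [
    ([[(0, 3)]], [[(0, 3), (10, 12)]]),  # 1 of 1 predictions correct, 1 of 2 gold spans found
    ([[(5, 6)]], [[(7, 9)]]),  # a wrong prediction
]


def test_accumulates_and_resets() -> None:
    metric = ExactSpanF1()
    running, final = run_epoch(metric, batches)
    assert abs(running[0] - 2 / 3) < 1e-9  # P = 1, R = 1/2
    assert abs(running[1] - 0.4) < 1e-9  # cumulative over both batches: P = 1/2, R = 1/3
    assert abs(final - 0.4) < 1e-9
    assert metric.get_metric()["f1"] == 0.0  # reset=True left a clean state
    assert run_epoch(metric, batches) == (running, final)  # a second epoch is not polluted


test_accumulates_and_resets()
```

The metric stores raw counts rather than a running F score because the F score over several batches cannot be recovered from the per-batch F scores alone.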
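The distributed caveat can be shown with a small worked example. Assume, hypothetically, that two worker processes each hold their own copy of the four counters from the question's metric (stored here in plain dicts under made-up keys). Only the summed counters give the dataset-level F score; one worker's counters alone do not. In real multi-GPU training that sum would have to be done with something like torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) inside get_metric(); this sketch just simulates the sum in pure Python.

```python
from typing import Dict, List

Counters = Dict[str, float]  # hypothetical keys mirroring _s_sum, _s_cardinality, _t_sum, _t_cardinality


def f1_from_counters(c: Counters) -> float:
    precision = c["s_sum"] / c["s_card"] if c["s_card"] else 0.0
    recall = c["t_sum"] / c["t_card"] if c["t_card"] else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0


def sum_across_workers(per_worker: List[Counters]) -> Counters:
    # Simulates an all-reduce with SUM: every counter is added up over all workers.
    return {key: sum(c[key] for c in per_worker) for key in per_worker[0]}


worker_0 = {"s_sum": 3.0, "s_card": 4, "t_sum": 2.0, "t_card": 5}
worker_1 = {"s_sum": 1.0, "s_card": 1, "t_sum": 1.0, "t_card": 2}

local_only = f1_from_counters(worker_0)  # computed from worker 0's counters alone: 12/23
global_f1 = f1_from_counters(sum_across_workers([worker_0, worker_1]))  # from summed counters: 24/43
print(f"worker 0 alone: {local_only:.4f}, summed counters: {global_f1:.4f}")
```

Summing the raw counters, rather than averaging each worker's F score, is what keeps the result identical to a single-process run over the same data.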
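To see why pooling over the whole dataset and SQuAD-style per-document averaging can give different numbers, here is a toy calculation. For simplicity it uses exact-match counts per document, a hypothetical simplification of the overlap-based sums in the question, and the function names are made up for illustration.

```python
from typing import List, Tuple

DocCounts = Tuple[int, int, int]  # (matched, num_predicted, num_gold) for one document


def f1(matched: float, num_predicted: float, num_gold: float) -> float:
    precision = matched / num_predicted if num_predicted else 0.0
    recall = matched / num_gold if num_gold else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0


def pooled_f1(docs: List[DocCounts]) -> float:
    # Add up the counts over all documents first, then compute one F score
    # (the approach taken by the metric in the question).
    matched, num_predicted, num_gold = (sum(col) for col in zip(*docs))
    return f1(matched, num_predicted, num_gold)


def averaged_f1(docs: List[DocCounts]) -> float:
    # Compute an F score per document, then average them (SQuAD-style).
    return sum(f1(*doc) for doc in docs) / len(docs)


docs = [
    (1, 1, 1),  # short article, perfect prediction: F1 = 1.0
    (1, 4, 4),  # long article, 1 of 4 spans right: F1 = 0.25
]
print(pooled_f1(docs), averaged_f1(docs))
```

Pooling weights every span equally, so articles with many spans dominate; averaging weights every document equally. Neither is wrong, but the two numbers are not interchangeable when comparing results.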