pytorch, pytorch-lightning, torchmetrics

How to use median as the reduction function for loss in PyTorch Lightning Trainer?


I want to log the median loss every epoch in PyTorch Lightning. I tried self.log(..., reduce_fx=torch.median), which gave me the error: lightning_fabric.utilities.exceptions.MisconfigurationException: Only `self.log(..., reduce_fx={min,max,mean,sum})` are supported. If you need a custom reduction, please log a `torchmetrics.Metric` instance instead. Found: <function median at 0x7f53bbd59fc0>
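
To be concrete, here is a trimmed sketch of the attempt (calculate_metrics_and_loss is my own helper, shown further down):

def training_step(self, batch, batch_idx):
    loss, _ = self.calculate_metrics_and_loss(batch, batch_idx, metric=False)
    # rejected with the MisconfigurationException above:
    # self.log("train_loss", loss, on_epoch=True, reduce_fx=torch.median)
    # only the four built-in reductions get through, e.g.:
    self.log("train_loss", loss, on_epoch=True, reduce_fx="max")
    return loss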

So I implemented a subclass of torchmetrics.Metric and passed an instance of it as reduce_fx:

import torch
from torchmetrics import Metric

class MedianMetric(Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # scalar state holding the most recent median
        self.add_state("median", default=torch.tensor(0.0))

    def update(self, batch_losses):
        self.median = torch.median(batch_losses)

    def compute(self):
        return self.median.float()

    def __name__(self):
        return "median"

Then, in my LightningModule:

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self, ...):
        super().__init__()
        self._red = MedianMetric()
        ...
    ...
    def training_step(self, batch, batch_idx):
        loss, _ = self.calculate_metrics_and_loss(batch, batch_idx, metric=False)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, reduce_fx=self._red)
        return loss

Now I get TypeError: MedianMetric.update() takes 2 positional arguments but 3 were given. It seems Lightning wants preds and target as the arguments to update, but then I would be calculating the loss once in training_step and a second time in MedianMetric.compute, which doesn't seem right. Is there an easy way to keep track of the median loss per epoch using Lightning?


Solution

  • I could not figure out a way to do it with self.log in training_step. However, you can aggregate the batch losses yourself:

    class MyModel(pl.LightningModule):
        def __init__(self, ...):
            super().__init__()
            ...
            self._training_batch_losses = []
    
        ...
    
        def training_step(self, batch, batch_idx):
            # get batch loss
            ...
            # detach so the stored tensors do not keep the autograd graph alive
            self._training_batch_losses.append(batch_loss.detach())
            return batch_loss
    
        def on_train_epoch_end(self):
            epoch_median = torch.stack(self._training_batch_losses).median()
            self.log("train_loss_epoch", epoch_median)
            # reset for the next epoch
            self._training_batch_losses.clear()
    

    The same pattern works for validation as well: collect the losses in validation_step and compute the median in on_validation_epoch_end.
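
    For completeness, the error message's hint about logging a torchmetrics.Metric instance means logging the metric object itself rather than passing it as reduce_fx. A minimal, untested sketch of that pattern (MedianOfLosses and the _median attribute are illustrative names, not from the question):

    import torch
    from torchmetrics import Metric
    from torchmetrics.utilities import dim_zero_cat

    class MedianOfLosses(Metric):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # list state; "cat" concatenates the collected values across processes
            self.add_state("values", default=[], dist_reduce_fx="cat")

        def update(self, value: torch.Tensor) -> None:
            self.values.append(value.detach().flatten())

        def compute(self) -> torch.Tensor:
            return dim_zero_cat(self.values).median()

    Then, in the LightningModule, create self._median = MedianOfLosses() in __init__ and log the metric object:

        def training_step(self, batch, batch_idx):
            # compute batch_loss as before
            ...
            self._median.update(batch_loss)
            # logging the Metric object lets Lightning call compute()/reset() at epoch end
            self.log("train_loss_median", self._median, on_step=False, on_epoch=True)
            return batch_loss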