I want to log the median loss every epoch in Pytorch Lightning. I tried self.log(..., reduce_fx=torch.median)
which gave me error lightning_fabric.utilities.exceptions.MisconfigurationException: Only `self.log(..., reduce_fx={min,max,mean,sum})` are supported. If you need a custom reduction, please log a `torchmetrics.Metric` instance instead. Found: <function median at 0x7f53bbd59fc0>
So I implemented a subclass of torchmetrics.Metric and passed an instance of it as reduce_fx:
import torch
from torchmetrics import Metric


class MedianMetric(Metric):
    """Attempted custom Metric that tracks a median loss (see review notes)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): torch.Tensor(0) creates an *empty* size-0 tensor, not a
        # scalar zero — torch.tensor(0.0) is probably what was intended here.
        self.add_state("median", default=torch.Tensor(0))

    def update(self, batch_losses):
        # NOTE(review): _input_format is not defined on this class, and
        # torchmetrics.Metric does not provide it — it appears copied from the
        # torchmetrics docs example where the user defines it; calling it would
        # raise AttributeError. Overwriting the state also keeps only the
        # latest batch's median, not a median over the whole epoch.
        self.median = torch.median(self._input_format(batch_losses))

    def compute(self):
        # Return whatever was last stored in update(), cast to float.
        return self.median.float()

    def __name__(self):
        # NOTE(review): __name__ is conventionally an attribute, not a method;
        # defined this way, `metric.__name__` is a bound method rather than
        # the string "median".
        return "median"
in my LightningModule:
class MyModel(pl.LightningModule):
    def __init__(self, ...):
        # Single metric instance, intended to be handed to self.log as reduce_fx.
        self._red = MedianMetric()
        ...

    ...

    def training_step(self, batch, batch_idx):
        loss, _ = self.calculate_metrics_and_loss(batch, batch_idx, metric=False)
        # NOTE(review): this is where the TypeError occurs — presumably
        # Lightning invokes reduce_fx as a plain reduction callable with extra
        # positional arguments, so MedianMetric.update() receives more
        # arguments than its (self, batch_losses) signature accepts — confirm
        # against the Lightning version in use.
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, reduce_fx=self._red)
        return loss
Now I get TypeError: MedianMetric.update() takes 2 positional arguments but 3 were given
. I think it wants me to put preds and inputs as the arguments to update, but then I'm re-calculating the loss once in training_step and then a second time in MedianMetric.compute, which doesn't seem right. Is there an easy way to keep track of median loss per epoch using Lightning?
I could not figure out a way to do it with self.log in training_step. However, you can aggregate the batch losses yourself and log their median at the end of each epoch:
class MyModel(pl.LightningModule):
    """Logs the median training loss per epoch by aggregating batch losses manually."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Batch losses collected during the current epoch; reduced and
        # cleared in on_train_epoch_end.
        self._training_batch_losses = []

    def training_step(self, batch, batch_idx):
        # Compute your batch loss here.
        batch_loss = ...
        self._training_batch_losses.append(batch_loss)
        return batch_loss

    def on_train_epoch_end(self):
        # Reduce with a median — a reduction Lightning's built-in
        # reduce_fx options ({min, max, mean, sum}) do not offer.
        epoch_median = torch.stack(self._training_batch_losses).median()
        self.log("train_loss_epoch", epoch_median)
        # Reset so losses do not accumulate across epochs.
        self._training_batch_losses.clear()
The same pattern works for validation as well: append in validation_step and reduce/clear in on_validation_epoch_end.