Hi, I'm facing an issue gathering all the losses and predictions in a multi-GPU scenario. I'm using PyTorch Lightning 2.0.4 with DeepSpeed (distributed strategy: deepspeed_stage_2).
I'm adding my skeleton code here for reference.
def __init__(self):
    """Create the per-epoch accumulators used by the training hooks."""
    # NOTE(review): the enclosing class is not visible in this snippet; it is
    # presumably a LightningModule, whose base initialiser must run before
    # sub-modules can be attached to the instance.
    super().__init__()
    # One entry of predictions per training batch of the current epoch.
    self.batch_train_preds = []
    # One entry of loss per training batch of the current epoch.
    self.batch_train_losses = []
def training_step(self, batch, batch_idx):
    """Run the model on one batch and return its loss and predictions.

    Args:
        batch: Mapping that provides 'input_ids', 'attention_mask' and
            'labels'.
        batch_idx: Index of the batch within the epoch (unused here).

    Returns:
        A dict with the training loss under 'loss' and the arg-max class
        predictions under 'train_preds'.
    """
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    # The labels were previously read from an undefined name; this assumes
    # the batch carries them under 'labels' -- TODO confirm against the
    # dataloader.
    labels = batch['labels']
    # Model Step
    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    # Predicted class = index of the largest logit along the last dimension.
    train_preds = torch.argmax(outputs.logits, dim=-1)
    return {'loss': outputs[0],
            'train_preds': train_preds}
def on_train_batch_end(self, outputs, batch, batch_idx):
    """Record the loss and predictions of the batch that just finished.

    Args:
        outputs: The dict returned by ``training_step``; must provide
            'loss' and 'train_preds'.
        batch: The batch that was processed (unused here).
        batch_idx: Index of the batch within the epoch (unused here).

    Returns:
        A dict with the scalar batch loss under 'train_batch_loss' and the
        batch predictions under 'train_batch_preds'.
    """
    # aggregate metrics or outputs at batch level
    train_batch_loss = outputs["loss"].detach().mean()
    # 'train_preds' is already a single tensor, so it must not be passed to
    # torch.cat (which expects a sequence of tensors).
    train_batch_preds = outputs["train_preds"].detach()
    # Store the results so the epoch-end hook has something to aggregate.
    # The loss is stored as a Python float so it can be averaged regardless
    # of the device it was computed on.
    self.batch_train_losses.append(train_batch_loss.item())
    self.batch_train_preds.append(train_batch_preds)
    return {'train_batch_loss': train_batch_loss,
            'train_batch_preds': train_batch_preds}
def on_train_epoch_end(self) -> None:
    """Aggregate the stored per-batch results and log the mean epoch loss.

    Only the values accumulated on the current process are combined here;
    nothing is exchanged with other processes.
    """
    # Nothing was recorded this epoch: skip instead of failing on an empty
    # concatenation / mean.
    if not self.batch_train_preds or not self.batch_train_losses:
        return
    # Aggregate epoch level training metrics
    epoch_train_preds = torch.cat(self.batch_train_preds)  # input for extra metrics
    # float() accepts Python numbers as well as one-element tensors on any
    # device, so the mean does not depend on how the losses were stored.
    epoch_train_loss = float(np.mean([float(loss) for loss in self.batch_train_losses]))
    self.logger.log_metrics({"epoch_train_loss": epoch_train_loss})
    # Reset the accumulators so the next epoch starts clean and the stored
    # tensors can be freed.
    self.batch_train_preds.clear()
    self.batch_train_losses.clear()
In the above code block, I'm trying to combine all the predictions into a single tensor at the end of the epoch by tracking each batch in a global list (defined in __init__). But in multi-GPU training, I faced an error with concatenation, as each GPU handles its batches on its own device and I couldn't combine the results in a single global list.
What should I be doing in on_train_batch_end, on_train_epoch_end, or training_step in order to combine the results across all the GPUs into a list created in my __init__? I want to calculate some additional metrics (precision, recall, etc.) in the on_*_epoch_end() hooks for train, validation, and test
(validation and test are exactly analogous to my 3 training functions above, i.e. combining losses and predictions).
I have come across all_gather, which does combine the results the way I want, but it is called on all devices (GPUs).
Now the question is: how do I use only one device's output from all_gather? A code snippet would be very helpful.
The documentation suggests using all_gather.
Moreover, you do not need to manually accumulate the loss; just log it with self.log(..., on_epoch=True)
to let Lightning accumulate and log it correctly:
class MyLightningModule(LightningModule):
    """Collect training predictions per process and gather them at epoch end.

    The loss is logged through ``self.log`` so that it does not have to be
    accumulated by hand; only the predictions are stored manually.
    """

    def __init__(self):
        # The base initialiser must run before attributes and sub-modules
        # are set on the module.
        super().__init__()
        # NOTE(review): `self.model` is used in `training_step` but is not
        # created in this snippet; it is expected to be assigned elsewhere.
        # Predictions of every training batch seen by this process in the
        # current epoch.
        self.batch_train_preds = []

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and remember its predictions.

        Args:
            batch: Mapping that provides 'input_ids', 'attention_mask' and
                'labels'.
            batch_idx: Index of the batch within the epoch (unused here).

        Returns:
            The training loss for the batch.
        """
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        # Model Step
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )
        loss = outputs[0]
        train_preds = torch.argmax(outputs.logits, dim=-1)
        # Keep the predictions so they can be concatenated at epoch end.
        self.batch_train_preds.append(train_preds.detach())
        self.log('train/loss', loss, on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self) -> None:
        """Gather the epoch's predictions from all processes."""
        if not self.batch_train_preds:
            # Nothing was collected; torch.cat would fail on an empty list.
            return
        # Aggregate epoch level training metrics
        local_preds = torch.cat(self.batch_train_preds, dim=0)
        self.batch_train_preds.clear()  # free memory
        # the following will stack predictions from all the distributed
        # processes on a new leading dim=0
        # NOTE(review): this presumably requires every process to hold the
        # same number of predictions -- confirm for your sampler settings.
        epoch_train_preds = self.all_gather(local_preds)
        # reshape to (dataset_size, *other_dims): merge the per-process
        # dimension into the sample dimension, but only when the gather
        # actually added one.
        if epoch_train_preds.dim() == local_preds.dim() + 1:
            epoch_train_preds = epoch_train_preds.flatten(0, 1)
        # compute here your metrics over `epoch_train_preds`
If you want to compute the metric only on a single process, protect the metric computation with if self.trainer.global_rank == 0:
I also suggest taking a look at torchmetrics, which enables automatic synchronisation of metrics in a distributed setting with a few lines of code.
Additionally, I've written a framework for easy training and testing of several Transformer models for NLP.
from torchmetrics.classification import BinaryAccuracy
from lightning.pytorch import LightningModule
class MyLightningModule(LightningModule):
    """Track training accuracy with a torchmetrics metric instead of lists."""

    def __init__(self):
        # The base initialiser must run first: the metric assigned below is
        # a sub-module and cannot be attached before it.
        super().__init__()
        # NOTE(review): `self.model` is used in `training_step` but is not
        # created in this snippet; it is expected to be assigned elsewhere.
        self.train_accuracy = BinaryAccuracy()

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and update the accuracy metric.

        Args:
            batch: Mapping that provides 'input_ids', 'attention_mask' and
                'labels'.
            batch_idx: Index of the batch within the epoch (unused here).

        Returns:
            The training loss for the batch.
        """
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        # Model Step
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )
        loss = outputs[0]
        train_preds = torch.argmax(outputs.logits, dim=-1)
        self.train_accuracy(train_preds, labels)  # updates the metric internal state with predictions and labels
        self.log('train/loss', loss, on_step=True, on_epoch=True, sync_dist=True)
        self.log('train/acc', self.train_accuracy, on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self) -> None:
        """Intentionally empty hook."""
        pass  # no need to reset the metric as lightning will take care of that after each epoch