I've set up a basic UNet model. When I train it directly with a plain training function, it optimizes fine. However, with a similar loop in PyTorch Lightning, using a defined training_step, the loss never changes from its initial value. I took out the zero_grad/backward/step calls based on this tutorial. What am I doing wrong?
import torch

# Optimizes well
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to('cuda', dtype=torch.float), y.to('cuda', dtype=torch.float)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
I use the following as a method inside the UNet class, which I feed to pytorch_lightning.Trainer. The loss does not update from its initial value, and the model's predictions do not improve.
def training_step(self, batch, batch_idx):
    X, y = batch
    X, y = X.to(self.device, dtype=torch.float), y.to(self.device, dtype=torch.float)

    # Compute prediction error
    pred = self.forward(X)
    loss = self.loss_fn(pred, y)

    self.log("train_loss", loss)
    return loss
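For context, this is roughly how the module gets handed to the Trainer; the names (UNet, train_loader) and the Trainer arguments here are illustrative, not my exact setup:

import pytorch_lightning as pl

# Illustrative wiring; UNet and train_loader stand in for the real objects.
model = UNet()
trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1)
# Lightning calls training_step on each batch and performs
# zero_grad/backward/step internally.
trainer.fit(model, train_dataloaders=train_loader)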
It turned out that the issue was caused by the following method in the model class:
def configure_optimizers(self):
    return super().configure_optimizers()
One of the threads online recommended defining this together with training_step and train_dataloader as the minimal set of methods needed to run PyTorch Lightning. In fact, this override breaks optimization: the base LightningModule implementation does not return an optimizer, so Lightning configures none and the model weights are never updated, which is why the loss stays at its initial value. Simply deleting this method fixed the issue. Note also that it is Trainer.fit, not the LightningModule, that takes a dataloader and uses it to pass batches to training_step.
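Alternatively, configure_optimizers can be kept, but it must return a real optimizer. A minimal sketch of what that can look like; the choice of Adam and the learning rate of 1e-3 are illustrative:

import torch

def configure_optimizers(self):
    # Return an optimizer over this module's parameters;
    # Lightning then handles zero_grad/backward/step itself.
    return torch.optim.Adam(self.parameters(), lr=1e-3)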