I am using PyTorch to train my CNN. I want to plot my training and validation loss curves to visualize the model's performance. How can I plot the two curves?
I have the code below:
# RMSE loss written as a plain function so it can be used directly as `criterion`.
def RMSELoss(predicted, target, eps=0.0):
    """Return the root-mean-square error between *predicted* and *target*.

    Args:
        predicted: Tensor of model outputs.
        target: Tensor of reference values; it is subtracted element-wise
            from ``predicted``.
        eps: Optional constant added to the mean squared error before the
            square root. The default ``0.0`` gives the plain RMSE. A small
            positive value (e.g. ``1e-8``) keeps the gradient finite when
            the error is exactly zero, where the derivative of ``sqrt``
            is infinite.

    Returns:
        A scalar tensor holding ``sqrt(mean((predicted - target) ** 2) + eps)``.
    """
    mse = torch.mean((predicted - target) ** 2)
    return torch.sqrt(mse + eps)
# Train the model for a fixed number of epochs and record, for every epoch,
# the mean training loss and the mean validation loss.
criterion = RMSELoss
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
epochs = 300
n_total_steps = len(train_dataset)

# One entry per epoch; these two lists are what gets plotted afterwards.
trainingEpoch_loss = []
validationEpoch_loss = []

for epoch in range(epochs):
    # ---- training pass ----
    step_loss = []
    model.train()
    for i, data in enumerate(train_dataset):
        feature = data['data'].type(torch.FloatTensor)
        target = torch.tensor(data['target']).type(torch.FloatTensor)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        outputs = model(feature)
        training_loss = criterion(outputs, target)
        # Back-propagate and update the weights.
        training_loss.backward()
        optimizer.step()

        step_loss.append(training_loss.item())
        # Logs every step; raise the modulus to log less often.
        if (i + 1) % 1 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Loss: {training_loss.item():.4f}')
    trainingEpoch_loss.append(np.array(step_loss).mean())

    # ---- validation pass ----
    # eval() puts layers such as dropout / batch-norm into inference mode.
    model.eval()
    # Must be created once per epoch, *outside* the batch loop; re-creating
    # it per batch would leave only the last batch's loss in the average.
    validationStep_loss = []
    # No gradients are needed for validation, so skip building the graph.
    with torch.no_grad():
        for i, data in enumerate(val_dataset):
            feature = data['data'].type(torch.FloatTensor)
            target = torch.tensor(data['target']).type(torch.FloatTensor)
            outputs = model(feature)
            validation_loss = criterion(outputs, target)
            validationStep_loss.append(validation_loss.item())
    validationEpoch_loss.append(np.array(validationStep_loss).mean())
Can you let me know whether I am doing this correctly? Also, please let me know how to plot the training and validation loss.
You are correct to collect your epoch losses in the trainingEpoch_loss and validationEpoch_loss lists.
Now, after the training, add code to plot the losses:
from matplotlib import pyplot as plt

# Each list holds one value per epoch, so the x-axis is the epoch index.
plt.plot(trainingEpoch_loss, label='train_loss')
plt.plot(validationEpoch_loss, label='val_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
# show must be called (with parentheses); a bare `plt.show` only references
# the function and never displays the figure.
plt.show()
Read the matplotlib documentation for fancier plotting features.