PyTorch calculate MSE and MAE


I would like to calculate the MSE and MAE of the model below. At the moment the code computes the MSE for each batch. What do I need to do to get the overall MSE value? Can I use the same code to calculate the MAE? Many thanks in advance.

import os

import torch
from tqdm import tqdm

model.eval()
mean_test_error = 0.0  # running sum of the per-batch MAE values
with torch.no_grad():  # gradients are not needed during evaluation
    for images, paths in tqdm(loader_test):
        images = images.to(device)
        # ground-truth counts, looked up by file name
        targets = torch.tensor([metadata['count'][os.path.split(path)[-1]] for path in paths]) # B
        targets = targets.float().to(device)

        # forward pass:
        output = model(images) # B x 1 x 9 x 9 (analogous to a heatmap)
        preds = output.sum(dim=[1,2,3]) # predicted cell counts (vector of length B)

        # logging:
        loss = torch.mean((preds - targets)**2)  # MSE of this batch
        count_error = torch.abs(preds - targets).mean()  # MAE of this batch
        mean_test_error += count_error
        writer.add_scalar('test_loss', loss.item(), global_step=global_step)
        writer.add_scalar('test_count_error', count_error.item(), global_step=global_step)

        global_step += 1

mean_test_error = mean_test_error / len(loader_test)  # mean of the per-batch MAEs
writer.add_scalar('mean_test_error', mean_test_error.item(), global_step=global_step)
print("Test count error: %f" % mean_test_error)
if mean_test_error < best_test_error:
    best_test_error = mean_test_error
    torch.save({'state_dict':model.state_dict(),
                'optimizer_state_dict':optimizer.state_dict(),
                'globalStep':global_step,
                'train_paths':dataset_train.files,
                'test_paths':dataset_test.files},checkpoint_path)
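For reference, here is a minimal plain-Python sketch of the difference between the overall errors and the average of per-batch errors that the loop above accumulates. The function names overall_errors and mean_of_batch_maes and the counts are hypothetical, chosen only for illustration:

```python
import math


def overall_errors(preds, targets):
    """Overall MAE, MSE and RMSE over all images (one predicted/true count per image)."""
    n = len(preds)
    abs_sum = sum(abs(p - t) for p, t in zip(preds, targets))
    sq_sum = sum((p - t) ** 2 for p, t in zip(preds, targets))
    mse = sq_sum / n
    return abs_sum / n, mse, math.sqrt(mse)


def mean_of_batch_maes(preds, targets, batch_size):
    """Average of per-batch MAEs, i.e. summing count_error per batch
    and dividing by the number of batches."""
    batch_maes = []
    for start in range(0, len(preds), batch_size):
        p = preds[start:start + batch_size]
        t = targets[start:start + batch_size]
        batch_maes.append(sum(abs(a - b) for a, b in zip(p, t)) / len(p))
    return sum(batch_maes) / len(batch_maes)


# Hypothetical counts for 5 test images.
preds = [10.0, 12.0, 7.0, 3.0, 20.0]
targets = [11.0, 12.0, 5.0, 3.0, 14.0]

mae, mse, rmse = overall_errors(preds, targets)
print(mae, mse, rmse)
print(mean_of_batch_maes(preds, targets, batch_size=2))  # -> 2.5
print(mean_of_batch_maes(preds, targets, batch_size=1))  # -> 1.8
```

With batch_size=2 the last batch holds a single image, so the average of the per-batch MAEs (2.5) differs from the overall MAE (1.8); with batch_size=1 the two agree.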

Solution

  • First of all, keep the batch size at 1 during the test phase for simplicity: then len(loader_test) equals the number of test images.

    This may be task-specific, but for a heat-map regression model, MAE and MSE are computed with the following equations:

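    MAE = \frac{1}{N} \sum_{i=1}^{N} \left| \hat{y}_i - y_i \right|

    MSE = \sqrt{\frac{1}{N} \sum_{i=1}^{N} \left( \hat{y}_i - y_i \right)^2}

    where N is the number of test images, y_i is the ground-truth count and \hat{y}_i the predicted count of image i. Note that this "MSE", as reported in the counting literature, is really a root mean squared error, which is why the code below takes a square root.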

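    This means that in your code you should replace the lines where you calculate the per-batch errors with the following, which accumulate summed errors instead of batch means (initialise running_mae = 0.0 and running_mse = 0.0 before the loop):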

    error = torch.abs(preds - targets).sum().item()  # summed absolute error of this batch
    squared_error = ((preds - targets) ** 2).sum().item()  # summed squared error of this batch
    running_mae += error
    running_mse += squared_error

    and later, once the loop over loader_test has finished:

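    import math  # at the top of the script

    n_images = len(loader_test.dataset)  # equals len(loader_test) when the batch size is 1
    mse = running_mse / n_images
    rmse = math.sqrt(mse)
    mae = running_mae / n_images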
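A self-contained sketch of the whole evaluation loop, assuming a loader that yields (images, targets) pairs (the question's loader yields paths and looks the counts up in metadata instead) and a model whose summed output is the predicted count. evaluate_counts is a hypothetical helper name. Dividing by the number of samples rather than the number of batches makes the result independent of the batch size:

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate_counts(model, loader, device="cpu"):
    """Return overall (MAE, MSE, RMSE) of predicted counts over every sample.

    Assumes each batch is (images, targets) and that summing the model output
    over all non-batch dimensions gives the predicted count.
    """
    model.eval()
    running_abs = 0.0
    running_sq = 0.0
    n_samples = 0
    for images, targets in loader:
        images = images.to(device)
        targets = targets.float().to(device)
        # one predicted count per image (same as output.sum(dim=[1,2,3]) for 4-D output)
        preds = model(images).flatten(start_dim=1).sum(dim=1)
        diff = preds - targets
        running_abs += diff.abs().sum().item()  # summed, not averaged, per batch
        running_sq += (diff ** 2).sum().item()
        n_samples += targets.numel()
    mae = running_abs / n_samples
    mse = running_sq / n_samples
    return mae, mse, math.sqrt(mse)


# Toy check: an identity "model" whose 1x1x1 output already is the count.
images = torch.tensor([10.0, 12.0, 7.0, 3.0, 20.0]).view(5, 1, 1, 1)
targets = torch.tensor([11.0, 12.0, 5.0, 3.0, 14.0])
loader = DataLoader(TensorDataset(images, targets), batch_size=2)
print(evaluate_counts(torch.nn.Identity(), loader))
```

On this toy data the function returns an MAE of 1.8 and an MSE of 8.2 whether batch_size is 1, 2 or 5.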