In particular, the following code block compares `nn.CrossEntropyLoss(reduction='mean')` with `loss_fn = nn.CrossEntropyLoss(reduction='none')` followed by `loss.mean()`. Surprisingly, the results are not the same.
```python
import torch
import torch.nn as nn

# Generate random predictions and labels
preds = torch.randn(8, 10, 100)
labels = torch.randint(high=100, size=(8, 10))

# Replace some labels with -100
labels[torch.rand(labels.shape) < 0.2] = -100
preds, labels = preds.view(-1, 100), labels.view(-1)

def compare_losses(preds, labels):
    # Define loss functions
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    loss_fn_mn = nn.CrossEntropyLoss(reduction='mean')

    # Compute per-element losses, then average manually
    losses = loss_fn(preds, labels)
    weighted_loss = losses.mean()

    # Compute mean loss using the built-in mean reduction
    loss = loss_fn_mn(preds, labels)

    # Check whether the results are identical
    return torch.isclose(loss, weighted_loss), loss.item(), weighted_loss.item()

compare_losses(preds, labels)
```
This returns:

```
(tensor(False), 4.997840404510498, 3.748380184173584)
```
This is due to the `ignore_index` parameter. You can pass a specific label value to the loss function that will be ignored (e.g. for padding tokens), and its default is -100.

In your code, you set some label values to -100:

```python
labels[torch.rand(labels.shape) < 0.2] = -100
```

The built-in mean reduction excludes these positions entirely, whereas with `reduction='none'` they still appear in the output as zero-valued losses, so they drag down your manual `losses.mean()`.
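To see `ignore_index` in isolation (with made-up toy values, not your data), marking a position with the ignored label is equivalent to dropping that position from the batch before averaging:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
preds = torch.randn(4, 5)  # 4 positions, 5 classes
labels = torch.tensor([0, 1, 2, 3])

# Same predictions, but mark the last position as padding with label -1
labels_pad = torch.tensor([0, 1, 2, -1])

# ignore_index=-1 drops the last position from the mean
loss_pad = nn.CrossEntropyLoss(ignore_index=-1)(preds, labels_pad)

# Same result as averaging over just the first three positions
loss_first3 = nn.CrossEntropyLoss()(preds[:3], labels[:3])
```

Here `torch.isclose(loss_pad, loss_first3)` is `True`: the ignored position contributes neither to the sum nor to the divisor.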
If we change the manual calculation to ignore the -100 values, just as the default mean reduction does, we get matching answers:
```python
import torch
import torch.nn as nn

# Generate random predictions and labels
preds = torch.randn(8, 10, 100)
labels = torch.randint(high=100, size=(8, 10))

# Replace some labels with -100
labels[torch.rand(labels.shape) < 0.2] = -100
preds, labels = preds.view(-1, 100), labels.view(-1)

def compare_losses(preds, labels):
    # Define loss functions
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    loss_fn_mn = nn.CrossEntropyLoss(reduction='mean')

    # Compute per-element losses, then average only the non-ignored positions
    losses = loss_fn(preds, labels)
    weighted_loss = losses[labels != -100].mean()  # ignore -100 labels

    # Compute mean loss using the built-in mean reduction
    loss = loss_fn_mn(preds, labels)

    # Check whether the results are identical
    return torch.isclose(loss, weighted_loss), loss.item(), weighted_loss.item()

compare_losses(preds, labels)
```

```
(tensor(True), 5.122199058532715, 5.122198581695557)
```
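As a side note, assuming no class `weight` is passed, the built-in mean reduction is equivalent to summing the per-element losses and dividing by the count of non-ignored targets. Because `reduction='none'` returns exact zeros at ignored positions, the sum is unaffected by them and only the divisor matters. A quick sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
preds = torch.randn(80, 100)
labels = torch.randint(high=100, size=(80,))
labels[torch.rand(labels.shape) < 0.2] = -100

losses = nn.CrossEntropyLoss(reduction='none')(preds, labels)
mask = labels != -100

# Ignored positions contribute zero loss, so the full sum equals the masked sum;
# dividing by the non-ignored count reproduces the built-in mean reduction
manual = losses.sum() / mask.sum()
builtin = nn.CrossEntropyLoss(reduction='mean')(preds, labels)
```

This also makes the original discrepancy concrete: `losses.mean()` divides the same sum by the full batch size instead of the non-ignored count, which is why it came out smaller.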