Tags: python, pytorch, loss-function

Inconsistent results between PyTorch loss function for `reduction=mean`


In particular, the following code block compares using

nn.CrossEntropyLoss(reduction='mean') with nn.CrossEntropyLoss(reduction='none')

followed by losses.mean().

The results are surprisingly not the same.

import torch
import torch.nn as nn

# Generate random predictions and labels
# preds: logits of shape (batch=8, seq_len=10, num_classes=100)
preds = torch.randn(8, 10, 100)  
labels = torch.randint(high=100, size=(8, 10)) 
# replace some values with -100
# NOTE: -100 is nn.CrossEntropyLoss's default ignore_index, so these
# positions are skipped by the 'mean' reduction but still show up (as 0)
# in the 'none' output — the root of the mismatch discussed below.
labels[torch.rand(labels.shape) < 0.2] = -100

# Flatten to the (N, C) logits / (N,) targets shapes CrossEntropyLoss expects
preds, labels = preds.view(-1, 100), labels.view(-1)

def compare_losses(preds, labels):
    """Contrast the built-in 'mean' reduction with a manual losses.mean().

    Returns a tuple (isclose_tensor, built_in_mean, manual_mean) so the
    two reductions can be compared directly.
    """
    # Per-element losses, averaged by hand over *all* positions
    per_elem = nn.CrossEntropyLoss(reduction='none')(preds, labels)
    manual_mean = per_elem.mean()

    # Built-in mean reduction computed in one shot
    builtin_mean = nn.CrossEntropyLoss(reduction='mean')(preds, labels)

    # Report whether the two averages agree, plus both values
    return torch.isclose(builtin_mean, manual_mean), builtin_mean.item(), manual_mean.item()

compare_losses(preds, labels)

Returns

(tensor(False), 4.997840404510498, 3.748380184173584)

Solution

  • This is due to the ignore_index parameter. You can pass a specific label index value to the loss function that will be ignored (i.e., for padding tokens and the like). The default value for this parameter is -100.

    In your code, you set some label values to -100:

    labels[torch.rand(labels.shape) < 0.2] = -100

    The associated loss values are not included in the mean reduction, but are included when you manually calculate losses.mean().

    If we change the loss calculation to ignore the -100 values like the default mean reduction does, we get matching answers:

    import torch
    import torch.nn as nn
    
    # Generate random predictions and labels
    # preds: logits of shape (batch=8, seq_len=10, num_classes=100)
    preds = torch.randn(8, 10, 100)  
    labels = torch.randint(high=100, size=(8, 10)) 
    # replace some values with -100
    # (-100 matches CrossEntropyLoss's default ignore_index)
    labels[torch.rand(labels.shape) < 0.2] = -100
    
    # Flatten to (N, C) logits and (N,) targets for CrossEntropyLoss
    preds, labels = preds.view(-1, 100), labels.view(-1)
    
    def compare_losses(preds, labels):
        """Compare the built-in 'mean' reduction with a manual mean that
        skips ignore_index (-100) targets, mirroring what 'mean' does.

        Returns a tuple (isclose_tensor, built_in_mean, manual_mean).
        """
        per_elem = nn.CrossEntropyLoss(reduction='none')(preds, labels)
        # Mask out ignored positions before averaging, exactly as the
        # built-in 'mean' reduction does internally
        valid = labels != -100
        manual_mean = per_elem[valid].mean()

        # Built-in mean reduction computed in one shot
        builtin_mean = nn.CrossEntropyLoss(reduction='mean')(preds, labels)

        return torch.isclose(builtin_mean, manual_mean), builtin_mean.item(), manual_mean.item()
    
    compare_losses(preds, labels)
    
    > (tensor(True), 5.122199058532715, 5.122198581695557)