machine-learning, pytorch, cross-entropy

PyTorch CrossEntropyLoss documentation example crashes


To make sure I'm using PyTorch CrossEntropyLoss correctly, I'm trying the examples from the documentation: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html

However, the first example (target with class indices) doesn't seem to update the weights, and the second example (target with class probabilities) crashes.

Focusing on the second example, since a crash is the more obvious kind of error, the complete program I'm running is

import torch
from torch import nn

# Example of target with class probabilities
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)

And the error message is

Traceback (most recent call last):
  File "crossentropy-probabilities.py", line 9, in <module>
    output = loss(input, target)
  File "C:\Users\russe\Anaconda3\envs\torch2\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\russe\Anaconda3\envs\torch2\lib\site-packages\torch\nn\modules\loss.py", line 948, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "C:\Users\russe\Anaconda3\envs\torch2\lib\site-packages\torch\nn\functional.py", line 2422, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "C:\Users\russe\Anaconda3\envs\torch2\lib\site-packages\torch\nn\functional.py", line 2218, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: 1D target tensor expected, multi-target not supported 

Is the documentation in error, or am I missing something obvious?


Solution

  • You are likely using a PyTorch version < 1.10.

    Whether this works depends on the version of PyTorch you are using. From version 1.10 onwards, the target tensor can be provided either as class indices or as class probabilities (soft labels); older versions only accept class indices, which is why the second documentation example fails with "1D target tensor expected".

    You can compare the documentation pages of nn.CrossEntropyLoss from 1.9.1 and 1.10: the probability-target form only appears in the latter. If you cannot upgrade, see the sketch below for a manual equivalent.
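
    As a fallback for versions older than 1.10, you can compute the soft-label cross-entropy yourself from log_softmax. This is a minimal sketch, not PyTorch's exact kernel, but with the default 'mean' reduction it matches nn.CrossEntropyLoss for probability targets:

    import torch
    import torch.nn.functional as F

    print(torch.__version__)  # probability targets require >= 1.10

    input = torch.randn(3, 5, requires_grad=True)
    target = torch.randn(3, 5).softmax(dim=1)

    # Cross entropy with soft labels: -sum_c p(c) * log q(c),
    # averaged over the batch (the default 'mean' reduction)
    output = -(target * F.log_softmax(input, dim=1)).sum(dim=1).mean()
    output.backward()  # gradients flow into `input` as usual

    On 1.10 and later this gives the same value as loss(input, target) from the question, up to floating-point error.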