pytorch cross-entropy

How to do PyTorch F.cross_entropy?


I have an output of shape batch_size x 14 x 100 (14 objects x 100 classes). I want to compute a cross-entropy loss against ground-truth class indices of shape batch_size x 14. However, when I use torch.nn.functional.cross_entropy, I get the error message "Expected target size [15, 100], got [15, 14]". Does anyone know what the reason is? Thank you in advance.


Solution

  • See the cross entropy documentation

    For higher dim inputs, the inputs and targets are expected to be of size (N, C, d_1, ... d_k) and (N, d_1, ... d_k) where N is the batch size and C is the number of classes.

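    The class dimension has to be dimension 1, so your output should be of shape batch_size x 100 x 14 rather than batch_size x 14 x 100. With your current layout, dimension 1 (size 14) is read as the number of classes and the trailing 100 as an extra dimension, which is why a target of size [15, 100] is expected. Swapping the last two dimensions, for example with output.permute(0, 2, 1), gives the expected layout; the target of shape batch_size x 14 can stay as it is.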
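
A minimal sketch of the fix, assuming random logits and targets with the shapes from the question (batch size 15, 14 objects, 100 classes). The variable names are illustrative and not taken from the original code. It shows two options that should give the same loss under the default mean reduction: moving the class dimension to position 1, or flattening the batch and object dimensions together.

```python
import torch
import torch.nn.functional as F

# Shapes taken from the question; the data itself is random (assumption).
batch_size, num_objects, num_classes = 15, 14, 100
logits = torch.randn(batch_size, num_objects, num_classes)          # (N, 14, 100)
target = torch.randint(0, num_classes, (batch_size, num_objects))   # (N, 14)

# Option 1: move the class dimension to position 1 -> (N, 100, 14).
# The target keeps its (N, 14) shape.
loss_permuted = F.cross_entropy(logits.permute(0, 2, 1), target)

# Option 2: treat every object as its own sample.
# logits -> (N * 14, 100), target -> (N * 14,)
loss_flat = F.cross_entropy(
    logits.reshape(-1, num_classes),
    target.reshape(-1),
)

# Both average over the same N * 14 per-object losses.
print(loss_permuted.item(), loss_flat.item())
```

Which option to use is mostly a matter of taste; the flattened form can be easier to read when there is only one extra dimension.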