Is there a standard/efficient way in PyTorch to handle cross-entropy loss for a classification problem where the number of classes depends on the sample?
Example: In a batch of size 3, I have:
logits1 = [0.2, 0.2, 0.6], labels1 = [0, 0, 1]
logits2 = [0.4, 0.1, 0.1, 0.4], labels2 = [1, 0, 0, 0]
logits3 = [0.2, 0.8], labels3 = [1, 0]
I am looking for the right way to compute `cross_entropy_loss(logits, labels)` on this batch.
Cross-entropy loss is used when a single output class is being predicted. When you say the number of classes depends on the sample, I assume you mean that the number of logits differs from sample to sample, but each sample still has exactly one correct class, so we are still in a cross-entropy setting.
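The labels in the question are one-hot vectors of varying length, whereas PyTorch's `CrossEntropyLoss` in the padded approach takes one class index per sample. A minimal sketch of that conversion, assuming every label vector contains exactly one `1` (`labels_onehot` is an illustrative name, not from the question):

```python
import torch

# one-hot labels from the question; each sample has a different number of classes
labels_onehot = [[0, 0, 1], [1, 0, 0, 0], [1, 0]]

# the position of the single 1 is the index of the correct class
# (assumes exactly one entry equals 1 in every vector)
labels = torch.tensor([y.index(1) for y in labels_onehot])
print(labels)  # tensor([2, 0, 0])
```

Python integers become an `int64` tensor by default, which is the dtype `CrossEntropyLoss` expects for class-index targets.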
In this case you can simply pad the samples with `-inf`, which is effectively ignored in the cross-entropy loss calculation.
```python
import torch
import torch.nn as nn

# start with our sequences
sequences = [
    [0.2, 0.2, 0.6],
    [0.4, 0.1, 0.1, 0.4],
    [0.2, 0.8],
]
sequences = [torch.tensor(i) for i in sequences]

# represent labels as class indices (int64), one per sample,
# which is the target format used here for CrossEntropyLoss
labels = torch.tensor([2, 0, 0]).long()

# pack sequences into a rectangular batch
# fill padding values with `-inf`
padding_value = float('-inf')
sequences_padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=padding_value)

# create `CrossEntropyLoss` with `reduction='none'`
# this makes the loss return the value for each input (i.e. no averaging)
# so we can compare values
loss = nn.CrossEntropyLoss(reduction='none')

# compute loss on individual sequences without padding
l1 = torch.stack([loss(sequences[i], labels[i]) for i in range(labels.shape[0])])

# compute loss on padded sequences
l2 = loss(sequences_padded, labels)

# check values match
assert torch.allclose(l1, l2)
```
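In practice the logits often come out of a model as one already-padded tensor rather than a list of variable-length rows. A minimal sketch of the same `-inf` trick for that case, assuming the padded entries hold arbitrary values (zeros here) and that a hypothetical `lengths` tensor records how many classes each sample really has; the valid positions are selected with a boolean mask and everything else is filled with `-inf` via `masked_fill`:

```python
import torch
import torch.nn.functional as F

# padded logits (zero-filled) as a model might produce them
logits = torch.tensor([
    [0.2, 0.2, 0.6, 0.0],
    [0.4, 0.1, 0.1, 0.4],
    [0.2, 0.8, 0.0, 0.0],
], requires_grad=True)
lengths = torch.tensor([3, 4, 2])  # hypothetical: number of real classes per sample
labels = torch.tensor([2, 0, 0])

# mask[i, j] is True for real classes (j < lengths[i]) and False for padding
positions = torch.arange(logits.shape[1])
mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

# replace padded positions with -inf so they receive zero probability
masked_logits = logits.masked_fill(~mask, float('-inf'))

loss = F.cross_entropy(masked_logits, labels)  # mean over the batch
loss.backward()

# the padded logits get exactly zero gradient
print(logits.grad[~mask])  # tensor([0., 0., 0.])
```

Because `masked_fill` is applied out of place, the original logits are untouched and the padded slots receive no gradient, so whatever the model writes there does not affect training.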
Padding with `-inf` works because cross entropy applies a softmax to the logits, which computes `exp(x)` for every value in the input, and `exp(-inf)` evaluates to 0. The padded positions therefore add nothing to the softmax denominator and have no impact on the loss.
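A quick numerical check of this argument: padding a row with `-inf` leaves the softmax over the real classes unchanged and assigns exactly zero probability to the padded slots. The last line shows a caveat that follows from the same reasoning: a label that points at a padded position has zero probability, so its loss is infinite, and labels must always refer to a real class.

```python
import torch
import torch.nn.functional as F

row = torch.tensor([0.2, 0.8])
padded = torch.tensor([0.2, 0.8, float('-inf'), float('-inf')])

p_row = F.softmax(row, dim=0)
p_padded = F.softmax(padded, dim=0)

# the real classes keep the same probabilities...
print(torch.allclose(p_row, p_padded[:2]))  # True
# ...and the padded classes get exactly zero probability
print(p_padded[2:])  # tensor([0., 0.])

# caveat: a target index that falls on padding yields an infinite loss
print(F.cross_entropy(padded.unsqueeze(0), torch.tensor([3])))  # tensor(inf)
```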