I am trying to compute the cross-entropy loss manually in PyTorch for an encoder-decoder model.
I used the code posted here to compute it: Cross Entropy in PyTorch.
I updated the code to discard padded tokens (-100). The final code is this:
import torch

class compute_crossentropyloss_manual:
    """
    y0 is the logits tensor with shape (batch_size, C)
    x has shape (batch_size,); its entries are integers from 0 to C-1, or ignore_index
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index = ignore_index

    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:  # <------ I added this if-statement
                continue
            loss = loss + torch.log(torch.exp(y1[class_index]) / torch.exp(y1).sum())
        loss = -loss / n_batch
        return loss
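(An aside, not part of the original question: the per-token term torch.log(torch.exp(y1[class_index]) / torch.exp(y1).sum()) is exactly the log-softmax of y1 at the target index, so torch.log_softmax gives the same value while avoiding overflow in torch.exp. A minimal check:)

import torch

y1 = torch.randn(50)                                    # one row of logits
class_index = 3
a = torch.log(torch.exp(y1[class_index]) / torch.exp(y1).sum())
b = torch.log_softmax(y1, dim=0)[class_index]           # log-sum-exp trick, overflow-safe
print(torch.allclose(a, b))                             # True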
To verify that it works, I tested it on a text generation task, computing the loss both with the nn.CrossEntropyLoss implementation and with this code. The two loss values are not identical (a concrete example is in the edit below). Am I missing something?
I tried to get the source code of nn.CrossEntropyLoss, but I wasn't able to. In this link, nn/functional.py at line 2955, you will see that the function delegates to another function called torch._C._nn.cross_entropy_loss; I can't find this function in the repo, presumably because it is implemented in the C++ (ATen) backend and only exposed to Python through the torch._C bindings.
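For reference, torch.nn.functional.cross_entropy is the Python-level wrapper around that call; a minimal sketch of calling it directly with ignore_index (standard PyTorch API, shown here just for orientation):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)               # (batch_size, C)
target = torch.tensor([1, 3, -100, 7])    # -100 entries are ignored
print(F.cross_entropy(logits, target, ignore_index=-100))
# same value as torch.nn.CrossEntropyLoss(ignore_index=-100)(logits, target)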
Edit:
I noticed that the differences appear only when I have -100 tokens in the gold labels.
Demo example:
from torch import nn

criterion = nn.CrossEntropyLoss()                  # PyTorch implementation
criterion2 = compute_crossentropyloss_manual()     # the manual class above

y = torch.randint(1, 50, (100, 50), dtype=float)
x = torch.randint(1, 50, (100,))
x[40:] = -100
print(criterion(y, x).item())
print(criterion2(y, x).item())
> 25.55788695847976
> 10.223154783391905
and when we don't have -100:

x[40:] = 30  # any valid class index (0 to C-1)
print(criterion(y, x).item())
print(criterion2(y, x).item())
> 24.684453267596453
> 24.684453267596453
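In hindsight, the two numbers from the -100 demo differ by exactly the ratio of valid targets to batch size: nn.CrossEntropyLoss with the default reduction='mean' averages over the 40 non-ignored targets, while the manual code still divides by all 100. A quick arithmetic check:

# sum/40 (nn.CrossEntropyLoss) rescaled to sum/100 (manual code)
print(25.55788695847976 * 40 / 100)   # ≈ 10.223154783391905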
I solved the problem by updating the code. I was already discarding the -100 tokens (the if-statement above), but I forgot to also shrink the denominator (called n_batch in the code above) by the number of discarded tokens. After doing that, the loss numbers are identical to the nn.CrossEntropyLoss values. The final code:
class CrossEntropyLossManual:
    """
    y0 is the logits tensor with shape (batch_size, C)
    x has shape (batch_size,); its entries are integers from 0 to C-1, or ignore_index
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index = ignore_index

    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:
                n_batch -= 1  # ignored tokens must not count toward the average
                continue
            loss = loss + torch.log(torch.exp(y1[class_index]) / torch.exp(y1).sum())
        loss = -loss / n_batch
        return loss
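For completeness, the same computation can be done without the Python loop; a vectorized sketch (my own restatement, not from the original post) using torch.log_softmax and a mask over the ignored targets:

import torch

def cross_entropy_manual_vectorized(y0, x, ignore_index=-100):
    # log_softmax is the numerically stable form of log(exp(y1[c]) / exp(y1).sum())
    log_probs = torch.log_softmax(y0, dim=1)            # (batch_size, C)
    mask = x != ignore_index                            # keep valid targets only
    # pick the log-probability of each valid target class
    picked = log_probs[mask].gather(1, x[mask].unsqueeze(1)).squeeze(1)
    return -picked.mean()                               # average over valid targets

# matches both the fixed class above and nn.CrossEntropyLoss(ignore_index=-100)
y = torch.randint(1, 50, (100, 50), dtype=torch.float)
x = torch.randint(1, 50, (100,))
x[40:] = -100
print(cross_entropy_manual_vectorized(y, x).item())
print(torch.nn.CrossEntropyLoss(ignore_index=-100)(y, x).item())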