I am working on a classification task, but for various reasons I had to remove the softmax and switch the loss from cross-entropy to MSE. To create a one-hot tensor for the labels (the targets), I do the following:
labels_onehot = nn.functional.one_hot(labels, num_classes=10).float()
But when I try to calculate the loss, an exception is thrown:
Cell In[13], line 121
116 print("Labels one-hot shape:", labels_onehot.shape)
118 loss = criterion(outputs, labels_onehot)
--> 121 loss = criterion(outputs, labels)
122 loss.backward()
123 optimizer.step()
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/loss.py:535, in MSELoss.forward(self, input, target)
534 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 535 return F.mse_loss(input, target, reduction=self.reduction)
File /opt/conda/lib/python3.10/site-packages/torch/nn/functional.py:3328, in mse_loss(input, target, size_average, reduce, reduction)
3325 if size_average is not None or reduce is not None:
3326 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3328 expanded_input, expanded_target = torch.broadcast_tensors(input, target)
3329 return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
File /opt/conda/lib/python3.10/site-packages/torch/functional.py:73, in broadcast_tensors(*tensors)
71 if has_torch_function(tensors):
72 return handle_torch_function(broadcast_tensors, tensors, *tensors)
---> 73 return _VF.broadcast_tensors(tensors)
RuntimeError: The size of tensor a (10) must match the size of tensor b (64) at non-singleton dimension 1
I printed the shapes of both tensors and they were the same, so I cannot see why the exception was thrown.
The problem is most likely in line 121 of your cell, the line the traceback points at:
loss = criterion(outputs, labels)
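Line 118 already computes the loss against labels_onehot, but line 121 then recomputes it with the raw labels. Those integer labels have shape (batch_size,), which appears to be (64,) here, and that cannot be broadcast against outputs of shape (batch_size, 10), hence the size mismatch at dimension 1. You should replace labels with labels_onehot if you are using one-hot encoding for your labels: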
loss = criterion(outputs, labels_onehot)
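To see the difference in isolation, here is a minimal sketch that reproduces the mismatch and the fix. It assumes a batch size of 64 and 10 classes (inferred from the error message), and it uses a random tensor as a stand-in for your model's outputs, so the names batch_size, num_classes and error_message are illustrative only:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed sizes, inferred from the error message (10 vs. 64).
batch_size, num_classes = 64, 10

# Random stand-in for the model's raw outputs (no softmax applied).
outputs = torch.randn(batch_size, num_classes, requires_grad=True)
# Integer class indices, shape (64,), as a DataLoader would typically yield.
labels = torch.randint(0, num_classes, (batch_size,))

criterion = nn.MSELoss()

# Passing the integer labels: shapes (64, 10) and (64,) cannot be broadcast,
# so MSELoss raises a RuntimeError.
try:
    criterion(outputs, labels)
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)
print("With integer labels:", error_message)

# Passing the one-hot labels: both tensors have shape (64, 10).
labels_onehot = nn.functional.one_hot(labels, num_classes=num_classes).float()
loss = criterion(outputs, labels_onehot)
loss.backward()
print("With one-hot labels, loss shape:", loss.shape)
```

If this matches your setup, deleting the second criterion call (or passing labels_onehot to it) should be enough, since the first call already uses the correct target.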