In the implementation of the SoftmaxWithLoss layer, I noticed that the gradient in the backward pass is divided by batch_size. Here's the code for reference:
class SoftmaxWithLoss:
    def __init__(self):
        self.loss = None  # Variable to store the loss value
        self.y = None     # Variable to store the output of the softmax function
        self.t = None     # Variable to store the target (label) data, which is expected to be a one-hot vector

    def forward(self, x, t):
        self.t = t                                       # Store the target data
        self.y = softmax(x)                              # Apply softmax to input 'x' to get the probabilities
        self.loss = cross_entropy_error(self.y, self.t)  # Calculate the cross-entropy loss
        return self.loss                                 # Return the loss value

    def backward(self, dout=1):
        batch_size = self.t.shape[0]         # Get the batch size
        dx = (self.y - self.t) / batch_size  # Calculate the gradient of the loss with respect to input 'x'
        return dx                            # Return the gradient
I traced the computational graph, and I think backward should just return y - t.
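Concretely, here is how I am calling it (a minimal sketch; it assumes the book's softmax and cross_entropy_error helpers are already in scope):

import numpy as np

x = np.array([[2.0, 1.0, 0.1],
              [0.5, 2.5, 0.3]])   # logits for a batch of 2
t = np.array([[1, 0, 0],
              [0, 1, 0]])         # one-hot labels

layer = SoftmaxWithLoss()
loss = layer.forward(x, t)
dx = layer.backward()

print(dx)               # (y - t) / 2, not y - t
print(dx * x.shape[0])  # this is the y - t I expected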
That depends on how the loss (cross_entropy_error) is defined for a (mini-)batch. The most common convention[*] is that the loss for a batch is the average of the losses of all items in the batch (rather than the sum, or the average over the whole dataset). So you get a constant factor of 1/batch_size, which then also appears in the backward pass: it enters the computation of the last gradient and from there scales every other gradient by the same constant factor.
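To make that concrete, here is a small self-contained check (a sketch, not the book's exact code) that defines a batch-averaged cross-entropy and verifies that its gradient through softmax is (y - t) / batch_size, by comparing against a finite-difference gradient:

import numpy as np

def softmax(x):
    x = x - x.max(axis=1, keepdims=True)   # shift for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy_error(y, t):
    # 'mean over batch' convention: average the per-item losses
    batch_size = y.shape[0]
    return -np.sum(t * np.log(y + 1e-7)) / batch_size

x = np.random.randn(4, 3)
t = np.eye(3)[[0, 2, 1, 0]]                  # one-hot targets for a batch of 4

analytic = (softmax(x) - t) / x.shape[0]     # the gradient that backward() returns

numeric = np.zeros_like(x)
eps = 1e-5
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += eps
    xm[idx] -= eps
    numeric[idx] = (cross_entropy_error(softmax(xp), t)
                    - cross_entropy_error(softmax(xm), t)) / (2 * eps)

print(np.max(np.abs(analytic - numeric)))    # tiny, so the 1/batch_size factor belongs there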
Effectively, changing that constant is the same as changing the learning rate. The 'mean over batch' convention arguably makes the choice of learning rate a bit more predictable, but in general, how the learning rate should be chosen as a function of batch size is not a simple topic.
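For plain SGD the equivalence is immediate: the 1/batch_size factor in the gradient can be moved into the learning rate without changing the parameter update (a sketch with made-up numbers):

import numpy as np

batch_size = 4
lr = 0.1
grad_sum = np.random.randn(3, 3)        # gradient if the batch loss were summed
grad_mean = grad_sum / batch_size       # gradient if the batch loss is averaged

step_a = lr * grad_mean                 # 'mean' convention with learning rate lr
step_b = (lr / batch_size) * grad_sum   # 'sum' convention with learning rate lr / batch_size
print(np.allclose(step_a, step_b))      # True: identical updates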
[*] E.g. in PyTorch, CrossEntropyLoss defaults to reduction='mean'. Similarly in Keras, the default is reduction="sum_over_batch_size".
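A quick way to see the PyTorch convention (assuming PyTorch is installed): the gradients under reduction='mean' and reduction='sum' differ by exactly the batch size.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3, requires_grad=True)
targets = torch.tensor([0, 2, 1, 0])

grad_mean, = torch.autograd.grad(
    F.cross_entropy(logits, targets, reduction='mean'), logits)
grad_sum, = torch.autograd.grad(
    F.cross_entropy(logits, targets, reduction='sum'), logits)

print(torch.allclose(grad_mean * 4, grad_sum))   # True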
Side note: in the computational graph for a batch, you would have a copy of each activation node for every item in the batch (since activations differ between items), plus a single final 'mean' reduction node (or two nodes: sum and divide).
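A minimal sketch of that reduction node (hypothetical names, just to show where the 1/batch_size enters the backward pass):

import numpy as np

class MeanNode:
    # Final reduction node: forward averages the per-item losses,
    # backward spreads dout evenly, so every item receives dout / batch_size.
    def forward(self, per_item_losses):
        self.batch_size = per_item_losses.shape[0]
        return per_item_losses.mean()

    def backward(self, dout=1):
        return np.full(self.batch_size, dout / self.batch_size)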