I am trying to train a machine learning model where the loss function is binary cross entropy. Because of GPU limitations I can only use a batch size of 4, and I'm seeing a lot of spikes in the loss graph. So I'm thinking of back-propagating after some predefined larger batch size (>4). That is, I'll run 10 iterations with batch size 4, store the losses, and after the 10th iteration add the losses and back-propagate. Will that be similar to a batch size of 40?
TL;DR
Is f(a+b) = f(a) + f(b) true for binary cross entropy?
f(a+b) = f(a) + f(b) doesn't seem to be what you're after. This would imply that BCELoss is additive which it clearly isn't. I think what you really care about is if for some index i
# false
f(x, y) == f(x[:i], y[:i]) + f(x[i:], y[i:])
is true?
The short answer is no, because you're missing some scale factors. What you probably want is the following identity
# true
f(x, y) == (i / b) * f(x[:i], y[:i]) + (1.0 - i / b) * f(x[i:], y[i:])
where b is the total batch size.
This identity is used as motivation for the gradient accumulation method (see below). Also, this identity applies to any objective function which returns an average loss across each batch element, not just BCE.
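As a quick sanity check of the identity above, here is a minimal framework-free sketch. The helper bce is a hypothetical stand-in for a mean-reduced BCE (the default reduction of nn.BCELoss), and the data are random values made up for illustration.

```python
import math
import random

def bce(preds, targets):
    """Mean binary cross entropy over a batch (hypothetical stand-in for nn.BCELoss)."""
    total = 0.0
    for p, t in zip(preds, targets):
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(preds)

random.seed(0)
b = 10                                                  # total batch size
x = [random.uniform(0.05, 0.95) for _ in range(b)]     # made-up predictions in (0, 1)
y = [float(random.randint(0, 1)) for _ in range(b)]    # made-up binary targets
i = 3                                                   # split index

full = bce(x, y)
naive = bce(x[:i], y[:i]) + bce(x[i:], y[i:])
weighted = (i / b) * bce(x[:i], y[:i]) + (1.0 - i / b) * bce(x[i:], y[i:])

print(math.isclose(full, weighted))  # True: the weighted sum matches
print(math.isclose(full, naive))     # False: the plain sum overshoots
```

When every sub-batch has the same size, each weight reduces to 1 / num_sub_batches, which is the scaling used in the training loops further down.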
Caveat/Pitfall: Keep in mind that batch norm will not behave exactly the same with this approach, since in training mode it normalizes each forward pass using the statistics of the (sub-)batch it sees and updates its running statistics from them.
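To illustrate why sub-batches change batch norm's behaviour, here is a rough sketch in plain Python. The function normalize is a hypothetical simplification of the training-mode normalization step (no learnable scale or shift, no running statistics), and the activation values are made up.

```python
import statistics

def normalize(values, eps=1e-5):
    """Normalize with the batch's own mean and (biased) variance, batch-norm style."""
    mean = statistics.fmean(values)
    var = statistics.pvariance(values)
    return [(v - mean) / (var + eps) ** 0.5 for v in values]

# made-up activations of a single channel for a batch of 8
activations = [0.5, 2.0, -1.0, 3.5, 0.0, 1.0, -2.5, 4.0]

full = normalize(activations)                                   # one batch of 8
sub = normalize(activations[:4]) + normalize(activations[4:])   # two sub-batches of 4

max_diff = max(abs(a - b) for a, b in zip(full, sub))
print(max_diff)  # clearly nonzero: the normalized values differ
```

Because each sub-batch is normalized with its own mean and variance, the outputs (and therefore the loss and gradients) are not identical to those of the full batch.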
We can actually do a little better memory-wise than just computing the loss as a sum followed by backpropagation. Instead we can compute the gradient of each component in the equivalent sum individually and allow the gradients to accumulate. To better explain I'll give some examples of equivalent operations
Consider the following model
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        num_outputs = 5
        # assume input shape is 10x10
        self.conv_layer = nn.Conv2d(3, 10, 3, 1, 1)
        self.fc_layer = nn.Linear(10*5*5, num_outputs)

    def forward(self, x):
        x = self.conv_layer(x)
        x = F.max_pool2d(x, 2, 2, 0, 1, False, False)
        x = F.relu(x)
        x = self.fc_layer(x.flatten(start_dim=1))
        x = torch.sigmoid(x) # or omit this and use BCEWithLogitsLoss instead of BCELoss
        return x
# to ensure same results for this example
torch.manual_seed(0)
model = MyModel()
# the examples will work as long as the objective averages across batch elements
objective = nn.BCELoss()
# doesn't matter what type of optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
and let's say our data and targets for a single batch are
torch.manual_seed(1) # to ensure same results for this example
batch_size = 32
input_data = torch.randn((batch_size, 3, 10, 10))
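# the upper bound of randint is exclusive, so use 2 to get both 0s and 1s;
# the target shape must match the model's 5 outputs
targets = torch.randint(0, 2, (batch_size, 5)).float()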
The body of our training loop for an entire batch may look something like this
# entire batch
output = model(input_data)
loss = objective(output, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_value = loss.item()
print("Loss value: ", loss_value)
print("Model checksum: ", sum([p.sum().item() for p in model.parameters()]))
We could have computed the same result as a sum of scaled sub-batch losses:
# This is simpler if the sub-batch size is a factor of batch_size
sub_batch_size = 4
assert (batch_size % sub_batch_size == 0)
# for this to work properly the batch_size must be divisible by sub_batch_size
num_sub_batches = batch_size // sub_batch_size
loss = 0
for sub_batch_idx in range(num_sub_batches):
    start_idx = sub_batch_size * sub_batch_idx
    end_idx = start_idx + sub_batch_size
    sub_input = input_data[start_idx:end_idx]
    sub_targets = targets[start_idx:end_idx]
    sub_output = model(sub_input)
    # add loss component for sub_batch
    loss = loss + objective(sub_output, sub_targets) / num_sub_batches
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_value = loss.item()
print("Loss value: ", loss_value)
print("Model checksum: ", sum([p.sum().item() for p in model.parameters()]))
The problem with the previous approach is that in order to apply back-propagation, PyTorch needs to keep the intermediate results of the layers in memory for every sub-batch until backward is called. This ends up requiring a relatively large amount of memory, so you may still run into memory consumption issues.
To alleviate this problem, instead of computing a single loss and performing back-propagation once, we can perform gradient accumulation. This gives results equivalent to the previous version. The difference is that we perform a backward pass on each component of the loss, only stepping the optimizer once all of them have been backpropagated. This way the computation graph is freed after each sub-batch, which helps with memory usage. Note that this works because .backward() actually accumulates (adds) the newly computed gradients to the existing .grad member of each model parameter. This is why optimizer.zero_grad() must be called only once, before the loop, and not during or after.
# This is simpler if the sub-batch size is a factor of batch_size
sub_batch_size = 4
assert (batch_size % sub_batch_size == 0)
# for this to work properly the batch_size must be divisible by sub_batch_size
num_sub_batches = batch_size // sub_batch_size
# Important! zero the gradients before the loop
optimizer.zero_grad()
loss_value = 0.0
for sub_batch_idx in range(num_sub_batches):
    start_idx = sub_batch_size * sub_batch_idx
    end_idx = start_idx + sub_batch_size
    sub_input = input_data[start_idx:end_idx]
    sub_targets = targets[start_idx:end_idx]
    sub_output = model(sub_input)
    # compute loss component for sub_batch
    sub_loss = objective(sub_output, sub_targets) / num_sub_batches
    # accumulate gradients
    sub_loss.backward()
    loss_value += sub_loss.item()
optimizer.step()
print("Loss value: ", loss_value)
print("Model checksum: ", sum([p.sum().item() for p in model.parameters()]))
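If the batch size is not divisible by the sub-batch size, the weighted identity from earlier suggests scaling each sub-batch by len(sub_batch) / batch_size instead of 1 / num_sub_batches. The following is a minimal sketch of that idea in plain Python rather than PyTorch: it assumes a hypothetical one-parameter logistic model with a hand-derived gradient, and the data and starting weights are made up for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def mean_bce_grad(w, bias, xs, ys):
    """Gradient of the mean BCE of sigmoid(w * x + bias) with respect to (w, bias)."""
    gw = gb = 0.0
    for x, y in zip(xs, ys):
        err = sigmoid(w * x + bias) - y  # derivative of per-sample BCE w.r.t. the logit
        gw += err * x
        gb += err
    n = len(xs)
    return gw / n, gb / n

# made-up data and parameters
xs = [0.5, -1.2, 2.0, 0.3, -0.7, 1.5, -2.1, 0.9, 1.1, -0.4]
ys = [1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]
w, bias = 0.3, -0.1
batch_size = len(xs)
sub_batch_size = 4  # does not divide 10, so the sub-batches have sizes 4, 4, 2

full_grad = mean_bce_grad(w, bias, xs, ys)

# accumulate weighted sub-batch gradients, mimicking repeated .backward() calls
acc_w = acc_b = 0.0
for start in range(0, batch_size, sub_batch_size):
    sub_x = xs[start:start + sub_batch_size]
    sub_y = ys[start:start + sub_batch_size]
    weight = len(sub_x) / batch_size  # i / b from the identity
    gw, gb = mean_bce_grad(w, bias, sub_x, sub_y)
    acc_w += weight * gw
    acc_b += weight * gb

print(math.isclose(acc_w, full_grad[0], abs_tol=1e-12))
print(math.isclose(acc_b, full_grad[1], abs_tol=1e-12))
```

In the PyTorch loop above, the analogous change would be to multiply each sub-batch loss by sub_input.shape[0] / batch_size before calling .backward(), assuming the objective averages over batch elements.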