machine-learning math pytorch backpropagation batchsize

Is binary cross entropy an additive function?


I am trying to train a machine learning model where the loss function is binary cross entropy. Because of GPU limitations I can only use a batch size of 4, and I'm seeing a lot of spikes in the loss graph. So I'm thinking of back-propagating only after reaching some predefined effective batch size (>4). For example, I would run 10 iterations with batch size 4, store the losses, and after the 10th iteration add the losses and back-propagate. Will that be similar to a batch size of 40?

TL;DR

Is f(a+b) = f(a) + f(b) true for binary cross entropy?


Solution

  • f(a+b) = f(a) + f(b) doesn't seem to be what you're after. That would imply that BCELoss is additive, which it clearly isn't. I think what you really care about is whether, for some index i,

    # false
    f(x, y) == f(x[:i], y[:i]) + f(x[i:], y[i:])
    

    is true.

    The short answer is no, because you're missing some scale factors. What you probably want is the following identity:

    # true
    f(x, y) == (i / b) * f(x[:i], y[:i]) + (1.0 - i / b) * f(x[i:], y[i:])
    

    where b is the total batch size.

    This identity is the motivation behind the gradient accumulation method (see below). Note that it applies to any objective function that returns the average loss over the batch elements, not just BCE.
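
    As a quick numerical check of the identity (a minimal sketch I'm adding here, not part of the original setup; the tensor shapes are arbitrary), nn.BCELoss with its default reduction='mean' satisfies it up to floating-point error:

    import torch
    import torch.nn as nn
    
    torch.manual_seed(0)
    f = nn.BCELoss()                 # default reduction='mean': averages the loss
    b, i = 8, 3                      # total batch size and an arbitrary split index
    x = torch.rand(b, 5)             # "predictions" already in (0, 1)
    y = torch.randint(0, 2, (b, 5)).float()
    
    lhs = f(x, y)
    rhs = (i / b) * f(x[:i], y[:i]) + (1.0 - i / b) * f(x[i:], y[i:])
    print(torch.allclose(lhs, rhs))  # expected: True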


    Caveat/Pitfall: Keep in mind that batch norm will not behave exactly the same with this approach, since it normalizes using the statistics of each forward pass and updates its running statistics once per forward pass, i.e. once per sub-batch rather than once per full batch.
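
    To make the caveat concrete (my own toy illustration, not part of the model below), the running statistics of a BatchNorm layer end up different when the same data is fed as sub-batches:

    import torch
    import torch.nn as nn
    
    torch.manual_seed(0)
    x = torch.randn(32, 3, 10, 10)
    
    bn_full = nn.BatchNorm2d(3)
    bn_split = nn.BatchNorm2d(3)
    
    bn_full(x)                     # one forward pass over the full batch
    for k in range(0, 32, 4):      # eight forward passes over sub-batches of 4
        bn_split(x[k:k+4])
    
    print(bn_full.running_mean)    # updated once
    print(bn_split.running_mean)   # updated eight times -> different values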


    We can actually do a little better memory-wise than computing the loss as a sum and then back-propagating. Instead, we can compute the gradient of each component of the equivalent sum individually and let the gradients accumulate. To better explain, I'll give some examples of equivalent operations.

    Consider the following model

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            num_outputs = 5
            # assume input shape is 10x10
            self.conv_layer = nn.Conv2d(3, 10, 3, 1, 1)
            self.fc_layer = nn.Linear(10*5*5, num_outputs)
    
        def forward(self, x):
            x = self.conv_layer(x)
            x = F.max_pool2d(x, kernel_size=2, stride=2)   # 10x10 -> 5x5
            x = F.relu(x)
            x = self.fc_layer(x.flatten(start_dim=1))
            x = torch.sigmoid(x)   # or omit this and use BCEWithLogitsLoss instead of BCELoss
            return x
    
    # to ensure same results for this example
    torch.manual_seed(0)
    model = MyModel()
    # the examples will work as long as the objective averages across batch elements
    objective = nn.BCELoss()
    # doesn't matter what type of optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    

    and let's say our data and targets for a single batch are

    torch.manual_seed(1)    # to ensure same results for this example
    batch_size = 32
    input_data = torch.randn((batch_size, 3, 10, 10))
    # binary 0/1 targets; the shape must match the model's 5 outputs
    targets = torch.randint(0, 2, (batch_size, 5)).float()
    

    Full batch

    The body of our training loop for an entire batch may look something like this

    # entire batch
    output = model(input_data)
    loss = objective(output, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_value = loss.item()
    
    print("Loss value: ", loss_value)
    print("Model checksum: ", sum([p.sum().item() for p in model.parameters()]))
    

    Weighted sum of loss on sub-batches

    We could have computed the same loss as a weighted sum of per-sub-batch losses:

    # for simplicity, assume batch_size is divisible by sub_batch_size
    sub_batch_size = 4
    assert batch_size % sub_batch_size == 0
    num_sub_batches = batch_size // sub_batch_size
    
    loss = 0
    for sub_batch_idx in range(num_sub_batches):
        start_idx = sub_batch_size * sub_batch_idx
        end_idx = start_idx + sub_batch_size
        sub_input = input_data[start_idx:end_idx]
        sub_targets = targets[start_idx:end_idx]
        sub_output = model(sub_input)
        # add loss component for sub_batch
        loss = loss + objective(sub_output, sub_targets) / num_sub_batches
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    loss_value = loss.item()
    
    print("Loss value: ", loss_value)
    print("Model checksum: ", sum([p.sum().item() for p in model.parameters()]))
    

    Gradient accumulation

    The problem with the previous approach is that, in order to apply back-propagation, PyTorch needs to keep the intermediate results of every layer in memory for every sub-batch until backward() is called. This ends up requiring a relatively large amount of memory, so you may still run into memory consumption issues.

    To alleviate this problem, instead of computing a single loss and performing back-propagation once, we can perform gradient accumulation. This gives results equivalent to the previous version. The difference is that we perform a backward pass on each component of the loss, and only step the optimizer once all of them have been back-propagated. This way the computation graph is freed after each sub-batch, which helps with memory usage. Note that this works because .backward() accumulates (adds) the newly computed gradients into the existing .grad member of each model parameter. This is why optimizer.zero_grad() must be called only once, before the loop, and not during or after.
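
    A minimal standalone illustration of this accumulation behaviour (my own toy example, not part of the training code): calling .backward() a second time adds to .grad rather than overwriting it.

    import torch
    
    w = torch.tensor([2.0], requires_grad=True)
    
    (w * 3).sum().backward()
    print(w.grad)      # tensor([3.])
    
    (w * 4).sum().backward()
    print(w.grad)      # tensor([7.]) -> 3 + 4, the gradients were accumulated
    
    w.grad.zero_()     # clearing the gradient, as optimizer.zero_grad() would

    The sub-batch loop below relies on exactly this behaviour.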

    # for simplicity, assume batch_size is divisible by sub_batch_size
    sub_batch_size = 4
    assert batch_size % sub_batch_size == 0
    num_sub_batches = batch_size // sub_batch_size
    
    # Important! zero the gradients before the loop
    optimizer.zero_grad()
    loss_value = 0.0
    for sub_batch_idx in range(num_sub_batches):
        start_idx = sub_batch_size * sub_batch_idx
        end_idx = start_idx + sub_batch_size
        sub_input = input_data[start_idx:end_idx]
        sub_targets = targets[start_idx:end_idx]
        sub_output = model(sub_input)
        # compute loss component for sub_batch
        sub_loss = objective(sub_output, sub_targets) / num_sub_batches
        # accumulate gradients
        sub_loss.backward()
        loss_value += sub_loss.item()
    optimizer.step()
    
    print("Loss value: ", loss_value)
    print("Model checksum: ", sum([p.sum().item() for p in model.parameters()]))