pytorch, loss-function, mini-batch

PyTorch: accumulating the loss when using mini-batches


I am new to PyTorch. May I ask what the difference is between adding loss.item() or not? Here are the two pieces of code:

for epoch in range(epochs):
    trainingloss = 0
    for i in range(0, X.size()[1], batch_size):
        indices = permutation[i:i + batch_size]
        F = model(X[:, indices])            # forward pass on the mini-batch
        optimizer.zero_grad()
        criterion = loss(F, X[:, indices])  # mini-batch loss (substitute your real targets)
        criterion.backward()
        optimizer.step()
        trainingloss += criterion.item()

and this:

for epoch in range(epochs):
    for i in range(0, X.size()[1], batch_size):
        indices = permutation[i:i + batch_size]
        F = model(X[:, indices])            # forward pass on the mini-batch
        optimizer.zero_grad()
        criterion = loss(F, X[:, indices])  # mini-batch loss (substitute your real targets)
        criterion.backward()
        optimizer.step()

If anyone has any idea, please help. Thank you very much.


Solution

Calling loss.item() gives you the loss as a plain Python number, detached from the computation graph that PyTorch builds (that is what .item() does for single-element tensors).

If you add the line trainingloss += criterion.item() at the end of each batch loop, you keep a running total of the loss over the epoch by adding up the loss of each mini-batch. This matters because you are training on mini-batches: the loss of a single mini-batch is not the loss over the whole training set.
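As an illustration, here is a self-contained sketch with a toy model, toy data and placeholder names (not your actual setup) that shows the accumulation, and why it should use .item() rather than the loss tensor itself:

import torch

model = torch.nn.Linear(10, 1)                       # toy model, stands in for yours
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()
X, y = torch.randn(64, 10), torch.randn(64, 1)       # toy data
batch_size = 16

running_loss = 0.0
for i in range(0, X.size(0), batch_size):
    xb, yb = X[i:i + batch_size], y[i:i + batch_size]
    optimizer.zero_grad()
    batch_loss = loss_fn(model(xb), yb)
    batch_loss.backward()
    optimizer.step()

    # running_loss += batch_loss        # would keep every batch's graph alive in memory
    running_loss += batch_loss.item()   # a plain Python float, safe to accumulate

print(running_loss / (X.size(0) // batch_size))      # average mini-batch loss for this epoch

The commented-out line would turn running_loss into a tensor that still references every batch's graph, which is exactly the memory issue described in the note below.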

Note: if you use a loss tensor outside the optimization loop, e.g. in a different scope (which can happen if you do something like return loss), it is important to call .item() on any tensor that is part of the computation graph. As a rule of thumb, any output, loss or model state that passes through PyTorch operations is likely part of that graph. Holding on to such tensors prevents the graph from being freed and can lead to CPU/GPU memory leaks. What you have above looks correct, though.
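For instance, if you wrap the update in a helper function (a hypothetical train_step, not something from your code), return the number rather than the tensor:

def train_step(model, optimizer, loss_fn, xb, yb):
    optimizer.zero_grad()
    loss = loss_fn(model(xb), yb)
    loss.backward()
    optimizer.step()
    return loss.item()   # a float; returning the loss tensor would keep its graph referenced by the caller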

Also, in the future, PyTorch's DataLoader class can handle the mini-batching for you with less boilerplate: it iterates over your dataset so that each item it yields is already a shuffled training batch, so you no longer need the manual permutation and index slicing.
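A rough sketch of what that could look like for your loop, assuming X stores one sample per column (which is what iterating over X.size()[1] suggests); adjust the transpose and the loss targets to your real setup:

from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(X.t())                     # transpose so each row is one sample
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

for epoch in range(epochs):
    trainingloss = 0
    for (batch,) in loader:                        # each iteration yields one shuffled mini-batch
        optimizer.zero_grad()
        output = model(batch)
        criterion = loss(output, batch)            # placeholder loss call, as in your code
        criterion.backward()
        optimizer.step()
        trainingloss += criterion.item()

Here shuffle=True takes over the role of your permutation tensor.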

I hope you enjoy learning/using PyTorch!