Tags: pytorch, vgg-net

Calling VGG many times causes an out of memory error


I want to extract the VGG features of a set of images and keep them in memory in a dictionary. The dictionary ends up holding 8091 tensors, each of shape (1, 4096), but my machine crashes with an out-of-memory error about 6% of the way through. Does anybody have a clue why this happens and how to prevent it?

In fact, the error seems to be triggered by the call to VGG itself rather than by the storage space, since merely storing the VGG classification output is enough to trigger it.

Below is the simplest code I've found that reproduces the error. First, a helper function:

import torch, torchvision
from tqdm import tqdm

# Pretrained VGG16; weights='DEFAULT' loads the best available weights
vgg = torchvision.models.vgg16(weights='DEFAULT')

def try_and_crash(gen_data):
    store_out = {}
    for i in tqdm(range(8091)):
        # Run the callable on a random image-sized input and keep the output
        my_output = gen_data(torch.randn(1, 3, 224, 224))
        store_out[i] = my_output
    return store_out

Calling it with a function that simply produces a large random tensor doesn't cause a fuss:

just_fine = try_and_crash(lambda x: torch.randn(1, 4096))

but calling it with vgg crashes the machine:

will_crash = try_and_crash(vgg)

Solution

  • The problem is that each element of the dictionary, store_out[i], also holds a reference to the autograd graph that produced it, including the intermediate activations saved for backpropagation, and therefore takes up far more memory than a plain 1x4096 tensor (a quick way to verify this is sketched at the end of this answer).

    Running the code under torch.no_grad(), or equivalently with torch.set_grad_enabled(False), solves the issue. We can test it by slightly changing the helper function:

    def try_and_crash_grad(gen_data, grad_enabled):
        store_out = {}
        for i in tqdm(range(8091)):
            # Toggle gradient tracking for the forward pass only
            with torch.set_grad_enabled(grad_enabled):
                my_output = gen_data(torch.randn(1, 3, 224, 224))
            store_out[i] = my_output
        return store_out
    

    Now the following works:

    works_fine = try_and_crash_grad(vgg, False)
    

    while the following throws an out-of-memory error:

    crashes = try_and_crash_grad(vgg, True)
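
    For completeness, here is a quick way to check what is being kept alive, as a minimal sketch (the exact grad_fn name may vary across PyTorch versions):

    out = vgg(torch.randn(1, 3, 224, 224))
    print(out.requires_grad, out.grad_fn)    # True, <AddmmBackward0 ...>: graph attached

    with torch.no_grad():
        out = vgg(torch.randn(1, 3, 224, 224))
    print(out.requires_grad, out.grad_fn)    # False, None: nothing to keep alive

    If wrapping the call is inconvenient, detaching the output from the graph before storing it should work just as well, since the graph behind each forward pass can then be freed:

    no_crash = try_and_crash(lambda x: vgg(x).detach())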