I want to extract the VGG features of a set of images and keep them in memory in a dictionary. The dictionary ends up holding 8091 tensors, each of shape (1, 4096), but my machine crashes with an out-of-memory error about 6% of the way through. Does anybody have a clue why this is happening and how to prevent it?
In fact, this seems to be triggered by the call to VGG rather than by the stored tensors themselves, since storing the VGG classification output is already enough to trigger the error.
Below is the simplest code I've found that reproduces the error. First, a helper function is defined:
import torch, torchvision
from tqdm import tqdm

vgg = torchvision.models.vgg16(weights='DEFAULT')

def try_and_crash(gen_data):
    store_out = {}
    for i in tqdm(range(8091)):
        my_output = gen_data(torch.randn(1, 3, 224, 224))
        store_out[i] = my_output
    return store_out
Calling it to quickly produce a large tensor doesn't cause a fuss:
just_fine = try_and_crash(lambda x: torch.randn(1,4096))
but calling it with vgg causes the machine to crash:
will_crash = try_and_crash(vgg)
The problem is that each element of the dictionary store_out[i] also keeps the autograd graph that produced it (including every intermediate activation needed for backpropagation), and therefore ends up being much larger than a plain 1x4096 tensor.
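You can check this directly by inspecting the output of a single forward pass (using the vgg model defined in the question):

out = vgg(torch.randn(1, 3, 224, 224))
print(out.requires_grad)     # True: the output is attached to an autograd graph
print(out.grad_fn)           # e.g. <AddmmBackward0 object ...>: the op that produced it
print(out.detach().grad_fn)  # None: detach() returns a tensor without the graph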
Running the code under torch.no_grad(), or equivalently with torch.set_grad_enabled(False), solves the issue. We can test it by slightly changing the helper function:
def try_and_crash_grad(gen_data, grad_enabled):
    store_out = {}
    for i in tqdm(range(8091)):
        with torch.set_grad_enabled(grad_enabled):
            my_output = gen_data(torch.randn(1, 3, 224, 224))
        store_out[i] = my_output
    return store_out
Now the following works:
works_fine = try_and_crash_grad(vgg, False)
while the following throws an out-of-memory error:
crashes = try_and_crash_grad(vgg, True)
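An alternative, if you want to keep gradient tracking enabled for other parts of your code, is to detach each output before storing it. Here is a minimal sketch along the lines of the helper above (same names, same random inputs):

def try_and_crash_detach(gen_data):
    store_out = {}
    for i in tqdm(range(8091)):
        my_output = gen_data(torch.randn(1, 3, 224, 224))
        # detach() keeps only the output values and drops the autograd graph,
        # so the graph built by each forward pass can be freed on the next iteration
        store_out[i] = my_output.detach()
    return store_out

also_fine = try_and_crash_detach(vgg)

On recent PyTorch versions, torch.inference_mode() is another option: it behaves like torch.no_grad() but adds a few extra optimizations for pure inference.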