python machine-learning pytorch out-of-memory resnet

Resnet out of memory: torch.OutOfMemoryError: CUDA out of memory


I'm training an end-to-end model on a video task. I use PyTorch's ResNet50 as the encoder, and the input shape is (1, seq_length, 3, 224, 224), where seq_length is the number of frames in each video. For example, if video 1 has 1500 frames, the input shape is (1, 1500, 3, 224, 224); if video 2 has 2000 frames, it is (1, 2000, 3, 224, 224). However, when I feed the input to the ResNet, CUDA runs out of memory in the first convolution layers of the forward pass.

I have tried:

  1. calling torch.cuda.empty_cache()
  2. setting pin_memory=True and prefetch_factor=2 in the dataloader
  3. setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
  4. reducing seq_length. This doesn't work because it significantly hurts performance: the video is sequential data, and reducing seq_length breaks that structure.

Are there any hacks that can help with this issue? Any help is appreciated.

Below is the full error message

Traceback (most recent call last):
  File "Path/E2E.py", line 143, in <module>
    p_classes1, phase_preds = model1(long_feature)
  File "Path/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "Path/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "Path/E2E.py", line 47, in forward
    x = self.resnet_lstm(x)
  File "Path/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "Path/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "Path/train_embedding.py", line 226, in forward
    x = self.share.forward(x)
  File "Path/lib/python3.9/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
  File "Path/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "Path/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "Path/lib/python3.9/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
  File "Path/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "Path/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "Path/lib/python3.9/site-packages/torchvision/models/resnet.py", line 155, in forward
    out = self.bn3(out)
  File "Path/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "Path/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "Path/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 176, in forward
    return F.batch_norm(
  File "Path/lib/python3.9/site-packages/torch/nn/functional.py", line 2512, in batch_norm
    return torch.batch_norm(
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 614.00 MiB. GPU 0 has a total capacity of 23.65 GiB of which 353.38 MiB is free. Including non-PyTorch memory, this process has 23.28 GiB memory in use. Of the allocated memory 22.84 GiB is allocated by PyTorch, and 3.77 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Here is the code of the ResNet50 model; the self.fc layer exists simply because I want to be able to load a pre-trained model.

import torch
import torch.nn as nn
from torchvision import models

class resnet_lstm(torch.nn.Module):
    def __init__(self):
        super(resnet_lstm, self).__init__()
        resnet = models.resnet50(pretrained=True)
        self.share = torch.nn.Sequential()
        self.share.add_module("conv1", resnet.conv1)
        self.share.add_module("bn1", resnet.bn1)
        self.share.add_module("relu", resnet.relu)
        self.share.add_module("maxpool", resnet.maxpool)
        self.share.add_module("layer1", resnet.layer1)
        self.share.add_module("layer2", resnet.layer2)
        self.share.add_module("layer3", resnet.layer3)
        self.share.add_module("layer4", resnet.layer4)
        self.share.add_module("avgpool", resnet.avgpool)
        self.fc = nn.Sequential(nn.Linear(2048, 512),
                                nn.ReLU(),
                                nn.Linear(512, 7))

    def forward(self, x):
        x = x.view(-1, 3, 224, 224)  # flatten (1, seq_length, 3, 224, 224) into a batch of frames
        x = self.share.forward(x)
        x = x.view(1, -1, 2048)      # back to (1, seq_length, 2048)
        return x

Here's the implementation of the dataset; each element of file_paths is a list of seq_length image paths:

import random

from torch.utils.data import Dataset, DataLoader, RandomSampler
from torchvision.datasets.folder import pil_loader  # or wherever pil_loader comes from

class CustomDataset(Dataset):
    def __init__(self, file_paths, file_labels, transform=None,
                 loader=pil_loader):
        self.file_paths = file_paths
        self.file_labels_phase = file_labels
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        img_names_list = self.file_paths[index]
        labels_phase = self.file_labels_phase[index]
        imgs = [self.loader(img_name) for img_name in img_names_list]
        if self.transform is not None:
            imgs = [self.transform(img) for img in imgs]

        return imgs, labels_phase, index

    def __len__(self):
        return len(self.file_paths)

class SeqSampler(RandomSampler):
    def __init__(self, data_source, seed=1):
        super().__init__(data_source)
        self.data_source = data_source
        self.seed = seed

    def __iter__(self):
        if self.seed is not None:
            random.seed(self.seed)
        
        # Generate a list of indices and shuffle them
        indices = list(range(len(self.data_source)))
        random.shuffle(indices)
        return iter(indices)

    def __len__(self):
        return len(self.data_source)

train_loaders = DataLoader(
            train_dataset_80,
            batch_size=1,
            sampler=SeqSampler(train_dataset_80),
            num_workers=1,
            pin_memory=True,
        )

Here's how I feed the input; this part is fairly standard:

    for data, labels_phase, _ in tqdm(train_loaders, desc=f"Epoch {epoch+1}/{max_epochs}"):
        
        long_feature = torch.tensor(np.array(data)).to(device)
        labels_phase = np.asarray(labels_phase).squeeze()

        optimizer1.zero_grad()

        labels_phase = torch.LongTensor(labels_phase).to(device)
        
        #Allocated: 2239.18 MB
        #Cached:    2248.00 MB
        p_classes = model1(long_feature)

Here's the output of nvidia-smi:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA RTX A6000               On  | 00000000:25:00.0 Off |                  Off |
| 30%   24C    P8              17W / 300W |     12MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      7388      G   /usr/lib/xorg/Xorg                            4MiB |
+---------------------------------------------------------------------------------------+

Solution

  • Fundamentally, your approach treats each frame of the video as a separate image.

    When you do

    for data, labels_phase, _ in tqdm(train_loaders, desc=f"Epoch {epoch+1}/{max_epochs}"):
    

    I don't know how your train_loaders is initialized, but if it uses a batch size b, then when you do x = x.view(-1, 3, 224, 224) in forward(), you are treating each frame in the video as a separate image, so the effective batch size becomes b * seq_length, which can get very large.
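
    As a back-of-the-envelope illustration (using the 1500-frame example from the question, and assuming fp32), the output of ResNet50's first convolution alone is already several GiB, and autograd keeps activations like this for every layer:

    seq_length = 1500                          # frames in the example video
    conv1_elems = seq_length * 64 * 112 * 112  # ResNet50 conv1 output shape is (N, 64, 112, 112)
    print(conv1_elems * 4 / 1024**3)           # ~4.5 GiB in fp32, before any later layer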

    I think this is the root cause of your memory issues. One obvious fix is to break each video into multiple chunks so that the effective seq_length comes down; a sketch follows below. To generate one embedding per video, you can then pool the embeddings of all the chunks (for example, by averaging them).
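
    Here is a minimal sketch of that chunking idea, assuming the self.share encoder from the question; encode_in_chunks and chunk_size=64 are illustrative names and values, not the asker's code. torch.utils.checkpoint trades compute for memory: intermediate activations of each chunk are recomputed during backward instead of being stored, so only one chunk's activations are live at a time.

    import torch
    from torch.utils.checkpoint import checkpoint

    def encode_in_chunks(encoder, frames, chunk_size=64):
        # frames: (seq_length, 3, 224, 224); encoder: e.g. the self.share stack
        feats = []
        for chunk in torch.split(frames, chunk_size, dim=0):
            # recompute this chunk's activations on backward instead of storing them
            feats.append(checkpoint(encoder, chunk, use_reentrant=False))
        return torch.cat(feats, dim=0)  # (seq_length, 2048, 1, 1) after avgpool

    One caveat: BatchNorm running statistics get updated again during the recomputation, so you may want to freeze the BatchNorm layers (or put them in eval mode) when combining checkpointing with a pretrained ResNet.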

    Alternative approaches

    Not only is your approach memory-inefficient, it also ignores the temporal semantics of the video (you treat a video as a huge set of separate, unconnected image frames).

    So you could instead try:

    1. 3D convolutions: the idea is to apply convolutions not only across the spatial dimensions of each frame, but also temporally, across frames of the same video. This would likely improve learning, but it wouldn't alleviate the memory issue because you still need to load each video in full.
    2. Sequence modelling: using sequence models like LSTMs would solve both problems (memory-wise, you don't strictly have to load each video completely when forward() gets called); see the sketch below.
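
    A minimal sketch of option 2, assuming per-frame ResNet features such as those produced by the hypothetical encode_in_chunks above; the feature size of 2048 and the 7-class head mirror the question's code, while FrameLSTM and its parameters are illustrative:

    import torch
    import torch.nn as nn

    class FrameLSTM(nn.Module):
        def __init__(self, feat_dim=2048, hidden=512, n_classes=7):
            super().__init__()
            self.lstm = nn.LSTM(feat_dim, hidden, batch_first=True)
            self.head = nn.Linear(hidden, n_classes)

        def forward(self, feats, state=None):
            # feats: (1, chunk_len, feat_dim). Passing the returned `state`
            # back in (detached) on the next call carries context across
            # chunks without keeping the whole video in memory.
            out, state = self.lstm(feats, state)
            return self.head(out), state  # per-frame logits, carried state

    Because the hidden state can be carried from chunk to chunk, the whole video never has to be resident on the GPU at once.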