python memory-leaks pytorch backpropagation memory-profiling

PyTorch model training CPU memory leak issue


When I trained my PyTorch model on a GPU device, my Python script was killed out of the blue. Digging into the OS log files, I found that the script was killed by the OOM killer because the machine ran out of CPU memory. It is very strange that I trained my model on the GPU yet ran out of CPU memory. (Snapshot of the OOM killer log file.)

In order to debug this issue, I installed a Python memory profiler. Looking at the profiler's log, I found that when the column-wise -= operation occurred, my CPU memory gradually increased until the OOM killer killed my program. (Snapshot of the Python memory profiler output.) It's very strange, and I tried many ways to solve this issue. Finally, I found that if I detach the tensor before the assignment operation, the problem goes away. Amazingly, it solves the issue, but I don't clearly understand why it works. Here is my original function code.

def GeneralizedNabla(self, image):
    pad_size = 2
    affinity = torch.zeros(image.shape[0], self.window_size**2, self.h, self.w).to(self.device)
    h = self.h + pad_size
    w = self.w + pad_size
    # pad = nn.ZeroPad2d(pad_size)
    image_pad = self.pad(image)
    for i in range(0, self.window_size**2):
        affinity[:, i, :, :] = image[:, :, :].detach()  # initialization
        dy = int(i / 5) - 2
        dx = int(i % 5) - 2
        h_start = pad_size + dy
        h_end = h + dy  # if 0 <= dy else h+dy
        w_start = pad_size + dx
        w_end = w + dx  # if 0 <= dx else w+dx
        affinity[:, i, :, :] -= image_pad[:, h_start:h_end, w_start:w_end].detach()
    self.Nabla = affinity
    return

If anyone has any ideas, I would appreciate it very much. Thank you.


Solution

  • Previously, when you did not call .detach() on your tensor, you were also accumulating the computation graph, and as the loop went on you kept accumulating more and more history until you exhausted your memory to the point that the program crashed (see the sketch below).
    When you call detach(), you effectively get the data without the previously entangled history that is needed for computing the gradients.
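
A minimal sketch (not your original model, just an illustration of the same effect, assuming a single tensor x that requires grad): repeatedly combining an un-detached tensor into an accumulator records new autograd nodes every iteration, and that bookkeeping is allocated on the host, which would explain CPU memory growing even though the tensors themselves live on the GPU.

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(256, 256, device=device, requires_grad=True)

    # Leaky pattern: every subtraction involving `x` (which requires grad)
    # records a new node in the computation graph, so the history grows with
    # each iteration and is only freed when `acc` is released or backward()
    # is called.
    acc = torch.zeros_like(x)
    for _ in range(1000):
        acc = acc - x * 0.01

    # Detached pattern: x.detach() shares the same data but is cut off from
    # the graph, so no history is recorded and memory stays flat.
    acc = torch.zeros_like(x)
    for _ in range(1000):
        acc = acc - x.detach() * 0.01

In your GeneralizedNabla function, detaching image_pad before the in-place -= plays the same role: the slices subtracted into affinity are plain data rather than graph-connected tensors, so no history accumulates across the loop.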