python, pytorch, conv-neural-network

LeNet5 & Self-Organizing Maps - RuntimeError: Trying to backward through the graph a second time - PyTorch


I am training a LeNet-5 CNN together with a Self-Organizing Map on MNIST data. The training code (abridged for brevity) is:

# SOM (flattened) weights-
# m = 40, n = 40, dim = 84 (LeNet's output shape/dim)
centroids = torch.randn(m * n, dim, device = device, dtype = torch.float32)

locs = [np.array([i, j]) for i in range(m) for j in range(n)]
locations = torch.LongTensor(np.asarray(locs)).to(device)
del locs

def get_bmu_distance_squares(bmu_loc):
    bmu_distance_squares = torch.sum(
        input = torch.square(locations.float() - bmu_loc),
        dim = 1
    )
    return bmu_distance_squares

distance_mat = torch.stack([get_bmu_distance_squares(loc) for loc in locations])

centroids = centroids.to(device)

num_epochs = 50
qe_train = list()

step = 1


for epoch in range(1, num_epochs + 1):
    qe_epoch = 0.0
    for x, y in train_loader:
        x = x.to(device)
        z = model(x)

        # SOM training code:

        batch_size = len(z)

        # Compute distances from batch to (all SOM units) centroids-
        dists = torch.cdist(x1 = z, x2 = centroids, p = p_norm)

        # Find closest (BMU) and retrieve the gaussian correlation matrix
        # for each point in the batch
        # bmu_loc is BS, num points-
        mindist, bmu_index = torch.min(dists, -1)
        # print(f"quantization error = {mindist.mean():.4f}")

        bmu_loc = locations[bmu_index]


        # Compute the SOM weight update:

        # Update LR
        # It is a matrix of shape (BS, centroids) or, (BS, mxn) and tells
        # for each input how much it will affect each (SOM unit) centroid-
        bmu_distance_squares = distance_mat[bmu_index]

        # Get current lr and neighbourhood radius for current step-
        decay_val = scheduler(it = step, tot =  int(len(train_loader) * num_epochs))
        curr_alpha = (alpha * decay_val).to(device)
        curr_sigma = (sigma * decay_val).to(device)

        # Compute Gaussian neighbourhood function-
        neighborhood_func = torch.exp(torch.neg(torch.div(bmu_distance_squares, ((2 * torch.square(curr_sigma)) + 1e-5))))

        expanded_z = z.unsqueeze(dim = 1).expand(-1, grid_size, -1)
        expanded_weights = centroids.unsqueeze(0).expand((batch_size, -1, -1))

        delta = expanded_z - expanded_weights
        lr_multiplier = curr_alpha * neighborhood_func

        delta.mul_(lr_multiplier.reshape(*lr_multiplier.size(), 1).expand_as(delta))
        delta = torch.mean(delta, dim = 0)
        new_weights = torch.add(centroids, delta)
        centroids = new_weights

        # return bmu_loc, torch.mean(mindist)

        # Compute quantization error loss-
        qe_loss = torch.mean(mindist)
        qe_epoch += qe_loss.item()

        # Empty accumulated gradients-
        optimizer.zero_grad()

        # Perform backprop-
        qe_loss.backward()

        # Update model trainable params-
        optimizer.step()
         
        step += 1


    qe_train.append(qe_epoch / len(train_loader))
    print(f"\nepoch = {epoch}, QE = {qe_epoch / len(train_loader):.4f}"
        f" & SOM wts L2-norm = {torch.norm(input = centroids, p = 2).item():.4f}"
    )

On trying to execute this code, I get the error:

The failing line is line 252, qe_loss.backward():

Traceback (most recent call last):
  File "c:\some_dir\som_lenet5.py", line 252, in <module>
    qe_loss.backward()
  File "c:\pytorch_venv\pytorch_cuda\lib\site-packages\torch\_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "c:\pytorch_venv\pytorch_cuda\lib\site-packages\torch\autograd\__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Solution

  • The issue is that autograd history is kept on centroids across iterations: the assignment centroids = new_weights leaves centroids attached to the graph of the batch it was computed from. This means that as early as the 2nd iteration, your computation of dists involves a centroids tensor that still references the 1st iteration's graph, whose saved intermediate values were already freed by the previous call to qe_loss.backward(). One way to prevent the gradient of centroids at iteration n from propagating back through iterations n-1, ..., down to iteration 1 is to detach centroids before updating its values, as sketched below.
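
Here is a minimal, self-contained sketch of that fix, using toy shapes and a plain Linear layer standing in for LeNet-5 (dim, grid_size, the learning rates and the simplified SOM update are illustrative assumptions, not the original hyper-parameters):

import torch

torch.manual_seed(0)

dim, grid_size = 8, 16                    # toy feature dim and SOM grid size (m * n)
model = torch.nn.Linear(4, dim)           # stand-in for the LeNet-5 feature extractor
optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
centroids = torch.randn(grid_size, dim)   # flattened SOM weights

for it in range(3):                       # stand-in for the batch loop
    x = torch.randn(5, 4)                 # toy batch
    z = model(x)

    # centroids carries no stale graph thanks to the detach below
    dists = torch.cdist(z, centroids)
    mindist, bmu_index = torch.min(dists, dim = -1)

    # Simplified SOM update: move every centroid towards the batch mean
    delta = 0.1 * (z.unsqueeze(1) - centroids.unsqueeze(0)).mean(dim = 0)

    # The fix: detach the updated centroids so that the next iteration's
    # cdist does not try to backprop into this iteration's (freed) graph
    centroids = (centroids + delta).detach()

    qe_loss = mindist.mean()
    optimizer.zero_grad()
    qe_loss.backward()                    # runs every iteration without retain_graph
    optimizer.step()

In the posted loop, the equivalent one-line change is centroids = new_weights.detach(); wrapping the whole centroid update (from expanded_z through new_weights) in a torch.no_grad() block should work just as well. Either way, qe_loss.backward() only ever sees the graph built from the current batch.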