I have a LeNet-5 CNN training with a Self-Organizing Map trained on MNIST data. The training code (for brevity) is:
# SOM (flattened) weights-
# m = 40, n = 40, n = 84 (LeNet's output shape/dim)
centroids = torch.randn(m * n, dim, device = device, dtype = torch.float32)
locs = [np.array([i, j]) for i in range(m) for j in range(n)]
locations = torch.LongTensor(np.asarray(locs)).to(device)
del locs
def get_bmu_distance_squares(bmu_loc):
bmu_distance_squares = torch.sum(
input = torch.square(locations.float() - bmu_loc),
dim = 1
)
return bmu_distance_squares
distance_mat = torch.stack([get_bmu_distance_squares(loc) for loc in locations])
centroids = centroids.to(device)
num_epochs = 50
qe_train = list()
step = 1
for epoch in range(1, num_epochs + 1):
qe_epoch = 0.0
for x, y in train_loader:
x = x.to(device)
z = model(x)
# SOM training code:
batch_size = len(z)
# Compute distances from batch to (all SOM units) centroids-
dists = torch.cdist(x1 = z, x2 = centroids, p = p_norm)
# Find closest (BMU) and retrieve the gaussian correlation matrix
# for each point in the batch
# bmu_loc is BS, num points-
mindist, bmu_index = torch.min(dists, -1)
# print(f"quantization error = {mindist.mean():.4f}")
bmu_loc = locations[bmu_index]
# Compute the SOM weight update:
# Update LR
# It is a matrix of shape (BS, centroids) or, (BS, mxn) and tells
# for each input how much it will affect each (SOM unit) centroid-
bmu_distance_squares = distance_mat[bmu_index]
# Get current lr and neighbourhood radius for current step-
decay_val = scheduler(it = step, tot = int(len(train_loader) * num_epochs))
curr_alpha = (alpha * decay_val).to(device)
curr_sigma = (sigma * decay_val).to(device)
# Compute Gaussian neighbourhood function-
neighborhood_func = torch.exp(torch.neg(torch.div(bmu_distance_squares, ((2 * torch.square(curr_sigma)) + 1e-5))))
expanded_z = z.unsqueeze(dim = 1).expand(-1, grid_size, -1)
expanded_weights = centroids.unsqueeze(0).expand((batch_size, -1, -1))
delta = expanded_z - expanded_weights
lr_multiplier = curr_alpha * neighborhood_func
delta.mul_(lr_multiplier.reshape(*lr_multiplier.size(), 1).expand_as(delta))
delta = torch.mean(delta, dim = 0)
new_weights = torch.add(centroids, delta)
centroids = new_weights
# return bmu_loc, torch.mean(mindist)
# Compute quantization error los-
qe_loss = torch.mean(mindist)
qe_epoch += qe_loss.item()
# Empty accumulated gradients-
optimizer.zero_grad()
# Perform backprop-
qe_loss.backward()
# Update model trainable params-
optimizer.step()
step += 1
qe_train.append(qe_epoch / len(train_loader))
print(f"\nepoch = {epoch}, QE = {qe_epoch / len(train_loader):.4f}"
f" & SOM wts L2-norm = {torch.norm(input = centroids, p = 2).item():.4f}"
)
On trying to execute this code, I get the error:
line 252: qe_loss.backward()
Traceback (most recent call last): File "c:\some_dir\som_lenet5.py", line 252, in <module>
qe_loss.backward() File "c:\pytorch_venv\pytorch_cuda\lib\site-packages\torch\_tensor.py", line 522, in backward
torch.autograd.backward( File "c:\pytorch_venv\pytorch_cuda\lib\site-packages\torch\autograd\__init__.py", line 266, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
The issue is that the gradient is kept on centroids
across iterations. This means that as early as the 2nd iteration, your computation of dists
will involve centroids
(updated at the end of the first iteration). When you back-propagate on that tensor, it will propagate back through to the 1st iteration. One way to prevent the gradient of centroids
at iteration n
from propagating through n-1
, up to iteration 1
is to detach centroids
before updating its values.