I ran into some weird behavior when trying to "manually" optimize a network's parameters via SGD. When I update the model's parameters in the following way, it works just fine:
# Manual SGD where the gradients are taken from the return value of
# torch.autograd.grad (one gradient tensor per model parameter).
for _ in trange(epochs):
    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = m(x)
        loss = F.cross_entropy(logits, y)
        grad = torch.autograd.grad(loss, m.parameters())
        # The update itself must not be recorded by autograd.
        with torch.no_grad():
            for p, g in zip(m.parameters(), grad):
                p.sub_(0.1 * g)
However, doing the following throws off the model completely:
# Manual SGD where the gradients are read from p.grad after loss.backward().
# NOTE: nothing in this loop resets p.grad between iterations.
for _ in trange(epochs):
    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = m(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        # The update itself must not be recorded by autograd.
        with torch.no_grad():
            for p in m.parameters():
                p.sub_(0.1 * p.grad)
But to me, both methods should be equivalent. Upon further inspection, when comparing the values of `g` (from `grad`) with the values of `p.grad` (from `m.parameters()`), it turned out that the gradient values are not the same! I also tried removing `with torch.no_grad():` and doing the following, but it didn't work either:
# Variant of the parameter update that writes through p.data instead of
# wrapping the update in torch.no_grad().
for p in m.parameters():
    step = 0.1 * p.grad
    p.data -= step
Can somebody please explain why this is happening? Shouldn't the gradients in both methods have the same values (keeping in mind that both models `m` are identical)?
REPRODUCIBLE EXAMPLE:
Ensure reproducibility:
# Fix every RNG seed and put cuDNN into deterministic mode so repeated runs
# of this script are comparable.
device = torch.device('cuda')  # every model and batch below is moved to this device
torch.manual_seed(0)  # seed torch's default RNG
torch.cuda.manual_seed_all(0)  # seed the RNG of every CUDA device
np.random.seed(0)  # seed NumPy's global RNG
torch.backends.cudnn.benchmark = False  # turn off cuDNN benchmark mode (its algorithm choice may vary between runs)
torch.backends.cudnn.deterministic = True  # restrict cuDNN to deterministic algorithms
torch.cuda.empty_cache()  # release cached GPU memory that is currently unused
Load the data:
# MNIST train/test splits, converted to tensors and wrapped in DataLoaders.
T = transforms.ToTensor()  # transform applied to every sample
train_data = datasets.MNIST(root='data', transform=T, download=True)
test_data = datasets.MNIST(root='data', transform=T, train=False, download=True)
BS = 300  # training batch size
epochs = 5  # number of passes over train_loader
LR = 0.1  # SGD step size used by the M1 and M2 training loops
# No `shuffle` argument is given, so batches follow DataLoader's default
# (unshuffled) order. pin_memory=True pairs with the non_blocking=True
# host-to-device copies in the loops below.
train_loader = DataLoader(train_data, batch_size=BS, pin_memory=True)
test_loader = DataLoader(test_data, batch_size=1000, pin_memory=True)
Define the model to be optimized:
class Model(nn.Module):
    """Small CNN classifier: three convolutional stages and a linear head.

    Args:
        out_dims: Number of channels produced by the first convolution. The
            second and third stages produce ``2 * out_dims`` and
            ``4 * out_dims`` channels respectively.

    The linear head takes ``out_dims * 4 * 16`` features, i.e. a 4x4 feature
    map after ``conv3``; a 1x28x28 input produces exactly that
    (28 -> 10 -> 8 -> 4). It outputs 10 logits per sample.
    """

    def __init__(self, out_dims: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, out_dims, 3, stride=3, padding=1)
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_dims, out_dims * 2, 3),
            nn.BatchNorm2d(out_dims * 2),
            nn.ReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(out_dims * 2, out_dims * 4, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_dims * 4),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc = nn.Linear(out_dims * 4 * 16, 10)

    def forward(self, x):
        """Apply the registered sub-modules to ``x`` in registration order.

        Returns:
            The output of the last child (``fc``).
        """
        # Walk the children directly instead of wrapping them in a throwaway
        # nn.Sequential on every forward pass; the result is the same.
        for layer in self.children():
            x = layer(x)
        return x
# Build one model (out_dims=5) on the target device ...
m1 = Model(5).to(device)
# ... and clone it, so that both training runs start from the same weights.
m2 = deepcopy(m1) # "m2.load_state_dict(m1.state_dict())" doesn't work either
Training and evaluation:
# M1's training:
# plain SGD, with the gradients taken from the return value of
# torch.autograd.grad (one gradient tensor per parameter).
for _ in trange(epochs):
    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = m1(x)
        loss = F.cross_entropy(logits, y)
        grad = torch.autograd.grad(loss, m1.parameters())
        # The update itself must not be recorded by autograd.
        with torch.no_grad():
            for p, g in zip(m1.parameters(), grad):
                p.sub_(LR * g)
# M1's evaluation:
# score every test batch and report the mean of the per-batch scores.
m1.eval()
acc1 = []
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        # .max(1) returns (values, indices); the indices are the predictions.
        pred = m1(x).max(1)[1]
        # NOTE(review): `metric` is defined outside this snippet; presumably
        # it returns the batch accuracy as a one-element tensor -- confirm.
        batch_score = metric(pred, y).item()
        acc1.append(batch_score)
print(f'Accuracy: {np.mean(acc1) * 100:.4}%')
# M2's training:
# plain SGD, with the gradients read from p.grad after loss.backward().
for _ in trange(epochs):
    for x, y in train_loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        loss = F.cross_entropy(m2(x), y)
        # backward() adds into p.grad instead of overwriting it, so the
        # gradients of the previous step have to be cleared first. Without
        # this, every update applies the running sum of all past gradients
        # and the run diverges from the autograd.grad-based one.
        m2.zero_grad()
        loss.backward()
        # The update itself must not be recorded by autograd.
        with torch.no_grad():
            for p in m2.parameters():
                p -= LR * p.grad
# M2's evaluation:
# score every test batch and report the mean of the per-batch scores.
m2.eval()
acc2 = []
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        # .max(1) returns (values, indices); the indices are the predictions.
        pred = m2(x).max(1)[1]
        # NOTE(review): `metric` is defined outside this snippet; presumably
        # it returns the batch accuracy as a one-element tensor -- confirm.
        batch_score = metric(pred, y).item()
        acc2.append(batch_score)
print(f'Accuracy: {np.mean(acc2) * 100:.4}%')
It took me a while to figure out, but the problem was in `loss.backward()`. Unlike `autograd.grad()`, which computes and returns the gradients, the in-place `backward()` computes the gradients and accumulates them into the `.grad` attribute of the leaf tensors of the computation graph (here, the model's parameters). In other words, the two have the same effect when used to back-propagate once, but every repeated call to `backward()` adds the newly computed gradients to all the previous ones (hence the divergence). Resetting the gradients with `model.zero_grad()` before each call to `backward()` fixes the problem.