How does the weight update work in PyTorch's dynamic computation graph when weights are shared (= reused multiple times)?
import random
import torch


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.
        """
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
        and reuse the middle_linear Module that many times to compute hidden layer
        representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement from Lua
        Torch, where each Module could be used only once.
        """
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred
I want to know what happens to the middle_linear weight at each backward pass, given that the module is reused multiple times within a single step.
When you call backward (either as the function or as a method on a tensor), the gradient of the tensor you called backward on is computed with respect to every operand that has requires_grad == True. These gradients are accumulated in the .grad property of those operands. If the same operand A appears multiple times in the expression, you can conceptually treat its occurrences as separate entities A1, A2, ... for the backpropagation algorithm, and at the end simply sum their gradients so that A.grad = A1.grad + A2.grad + ...
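To make this concrete, here is a minimal sketch (a standalone toy example, not the DynamicNet above; the layer sizes and input are made up) showing that reusing the same Linear module twice accumulates the per-use gradients into a single .grad:

import torch

# A small linear layer reused twice in one expression.
lin = torch.nn.Linear(2, 2, bias=False)
x = torch.ones(1, 2)

# Use the same module twice and backprop once.
(lin(x) + lin(x)).sum().backward()
grad_two_uses = lin.weight.grad.clone()

# Reset the accumulated gradient and use the module only once.
lin.weight.grad.zero_()
lin(x).sum().backward()
grad_one_use = lin.weight.grad.clone()

# The two-use gradient is the sum of the per-use gradients,
# i.e. exactly twice the single-use gradient in this setup.
print(torch.allclose(grad_two_uses, 2 * grad_one_use))  # True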
Now, strictly speaking, the answer to your question

    I want to know what happens to middle_linear weight at each backward

is: nothing. backward does not change the weights, it only calculates the gradients. To change the weights you have to perform an optimization step, for example using one of the optimizers in torch.optim. The weights are then updated according to their .grad property, so if your operand was used multiple times, it will be updated according to the sum of the gradients from each of its uses.
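The following is a minimal sketch of a single training step using the DynamicNet class defined above; the sizes, data, learning rate, and loss function are made up for illustration. It checks that backward alone leaves middle_linear.weight untouched, and that only optimizer.step() applies the accumulated gradient:

import torch

D_in, H, D_out = 4, 8, 2
model = DynamicNet(D_in, H, D_out)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = torch.nn.MSELoss()

x = torch.randn(16, D_in)
y = torch.randn(16, D_out)

before = model.middle_linear.weight.clone()

optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()                 # fills .grad; no weight has changed yet
assert torch.equal(model.middle_linear.weight, before)

optimizer.step()                # only now are the weights updated from .grad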
In other words, if a given weight element x receives a positive gradient from its first use and a negative gradient from its second use, the net effects may cancel out and x will stay roughly as it is (or change only slightly). If both uses call for x to be higher, it will be raised more than if it had been used just once, and so on.
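As an illustrative toy case (not from the original post), here is a scalar w used in two terms whose gradient contributions nearly cancel, so the accumulated gradient, and hence any later update, is small:

import torch

w = torch.tensor(2.0, requires_grad=True)
# The first term contributes d/dw = 1.0, the second d/dw = -0.9;
# backward sums them, so the combined gradient is only 0.1.
loss = w + (-0.9 * w)
loss.backward()
print(w.grad)  # approximately tensor(0.1000)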