deep-learning pytorch computation-graph

How does the weight update work in PyTorch's dynamic computation graph?


How does the weight update work in PyTorch code with a dynamic computation graph when weights are shared (i.e. reused multiple times)?

https://pytorch.org/tutorials/beginner/examples_nn/dynamic_net.html#sphx-glr-beginner-examples-nn-dynamic-net-py

import random
import torch

class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.
        """
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
        and reuse the middle_linear Module that many times to compute hidden layer
        representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement from Lua
        Torch, where each Module could be used only once.
        """
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred

I want to know what happens to middle_linear weight at each backward when it is used multiple times in a single step.


Solution

  • When you call backward (either as the function torch.autograd.backward or as a method on a tensor), the gradient of the tensor you called backward on is computed with respect to every operand with requires_grad == True. These gradients are accumulated in the .grad attribute of those operands. If the same operand A appears multiple times in the expression, you can conceptually treat each use as a separate entity A1, A2, ... during backpropagation and sum their gradients at the end, so that A.grad = A1.grad + A2.grad + ....
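The summation rule can be checked numerically. The following is a minimal sketch, assuming a bare weight matrix W in place of middle_linear; the names W, W1, W2 and the shapes are illustrative and not part of the tutorial code. W is applied twice in one forward pass, then the same computation is repeated with two independent copies, one per use.

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes and values, chosen only for illustration.
x = torch.randn(4)
W = torch.randn(4, 4, requires_grad=True)

# Shared weight: W is applied twice, like middle_linear inside the loop.
h = (W @ x).clamp(min=0)
h = (W @ h).clamp(min=0)
h.sum().backward()

# Same computation with two independent copies of W, one per use.
W1 = W.detach().clone().requires_grad_(True)
W2 = W.detach().clone().requires_grad_(True)
h = (W1 @ x).clamp(min=0)
h = (W2 @ h).clamp(min=0)
h.sum().backward()

# The gradient of the shared weight equals the sum over its individual uses.
grads_match = torch.allclose(W.grad, W1.grad + W2.grad)
print(grads_match)
```

Autograd does not create the copies itself; it simply adds each use's contribution into W.grad as it walks the graph, which gives the same result.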

    Now, strictly speaking, the answer to your question

    I want to know what happens to middle_linear weight at each backward

    is: nothing. backward does not change the weights; it only computes the gradients. To change the weights you have to take an optimization step, for example with one of the optimizers in torch.optim. The optimizer then updates the weights according to their .grad attribute, so if your operand was used multiple times, it is updated according to the sum of the gradients from each of its uses.
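To make the split between backward and the optimizer step concrete, here is a small sketch. It assumes a single hypothetical nn.Linear layer reused twice, plain SGD, and made-up sizes and learning rate; tanh is used instead of the tutorial's clamp only so that the illustrative gradient cannot be gated to zero.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for middle_linear; sizes and lr are illustrative only.
layer = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

x = torch.randn(2, 3)
before = layer.weight.detach().clone()

# Reuse the same layer twice in one forward pass.
h = torch.tanh(layer(x))
y = layer(h)
loss = y.pow(2).sum()

optimizer.zero_grad()
loss.backward()
# backward only fills layer.weight.grad; the weights themselves are untouched.
unchanged_after_backward = torch.equal(layer.weight.detach(), before)

optimizer.step()
# The optimizer step is what actually modifies the weights, using .grad.
changed_after_step = not torch.equal(layer.weight.detach(), before)

print(unchanged_after_backward, changed_after_step)
```

Note that .grad keeps accumulating across backward calls, which is why the sketch calls optimizer.zero_grad() before each backward.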

    In other words, if your matrix element x has a positive gradient from its first use and a negative gradient from its second, the net effects may cancel out and it will stay as it is (or change only slightly). If both uses call for x to be higher, it will rise more than if it had been used just once, and so on.
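The cancelling and reinforcing cases can be illustrated with a scalar toy example. This is only a sketch under assumed values: the helper grad_of, the coefficient a and the starting value 2.0 are hypothetical and chosen so the arithmetic is easy to follow.

```python
import torch

def grad_of(loss_fn):
    """Return d(loss)/dw at w = 2.0 for a scalar loss built from w."""
    w = torch.tensor(2.0, requires_grad=True)
    loss_fn(w).backward()
    return w.grad.item()

a = 3.0  # hypothetical coefficient

# One use of w: gradient is a.
single_use = grad_of(lambda w: a * w)
# Two uses with opposite signs: contributions +a and -a sum to zero.
cancelling = grad_of(lambda w: a * w - a * w)
# Two uses with the same sign: contributions +a and +a sum to 2a.
reinforcing = grad_of(lambda w: a * w + a * w)

print(single_use, cancelling, reinforcing)  # 3.0 0.0 6.0
```

With plain gradient descent the update is -lr * grad, so the cancelling case leaves w where it is, while the reinforcing case moves it twice as far as a single use would.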