python, deep-learning, pytorch, quantization

torch Parameter grad returns None


I want to implement a learned step size quantization algorithm, so I created a quantized Linear layer:

import torch
import torch.nn as nn
import torch.nn.functional as F


class QLinear(nn.Module):
    def __init__(self, input_dim, out_dim, bits=8):
        super(QLinear, self).__init__()
        # create a tensor requires_grad=True
        self.up = 2 ** bits - 1
        self.down = 0
        self.fc = nn.Linear(input_dim, out_dim)
        weight = self.fc.weight.data
        self.scale = nn.Parameter(torch.Tensor((torch.max(weight) - torch.min(weight)) / (self.up - self.down)), requires_grad=True)
        self.zero_point = nn.Parameter(torch.Tensor(self.down - (torch.min(weight) / self.scale).round()), requires_grad=True)

    def forward(self, x):
        weight = self.fc.weight
        quant_weight = (round_ste(weight / self.scale) + self.zero_point)
        quant_weight = torch.clamp(quant_weight, self.down, self.up)
        dequant_weight = ((quant_weight - self.zero_point) * self.scale)
        self.fc.weight.data = dequant_weight
        return self.fc(x)


class QNet(nn.Module):
    def __init__(self):
        super(QNet, self).__init__()
        self.fc1 = QLinear(28 * 28, 100)
        self.fc2 = QLinear(100, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = F.softmax(x, dim=1)
        return x

When I train this network, scale's grad is always None. Why does this happen, and how can I solve it?
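
(round_ste is not shown above; it is assumed to be the usual straight-through-estimator round, roughly:)

def round_ste(x):
    # round in the forward pass, but let the gradient pass through unchanged
    # (straight-through estimator), so the quantized values stay trainable
    return (x.round() - x).detach() + x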


Solution

  • The issue is that you are passing dequant_weight through the data attribute of your parameter, which ends up not being registered by autograd. A simple alternative is to keep the weight as an nn.Parameter and apply the linear operation manually in the forward definition, directly with the computed dequant_weight.
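
    You can see the same effect in isolation (a small sketch, not your exact layer): anything assigned through .data bypasses autograd, so the operations that produced the new value, including the multiplication by scale, never enter the graph:

    fc = nn.Linear(4, 4)
    scale = nn.Parameter(torch.tensor(0.5))

    dequant_weight = fc.weight * scale   # this multiplication is differentiable...
    fc.weight.data = dequant_weight      # ...but assigning via .data discards the graph

    fc(torch.rand(1, 4)).sum().backward()
    print(scale.grad)                    # None: scale never entered the computation graph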

    Here is a minimal example that should work:

    class QLinear(nn.Module):
        def __init__(self, input_dim, out_dim, bits=8):
            super().__init__()
            self.up = 2 ** bits - 1
            self.down = 0
    
            self.weight = nn.Parameter(torch.rand(out_dim, input_dim))
            self.scale = nn.Parameter(
                torch.Tensor((self.weight.max() - self.weight.min()) / (self.up - self.down)))
            self.zero_point = nn.Parameter(
                torch.Tensor(self.down - (self.weight.min() / self.scale).round()))
    
        def forward(self, x):
            # fake quantization: quantize the weight, clamp to the valid range, then dequantize
            quant_weight = torch.round(self.weight / self.scale) + self.zero_point
            quant_weight = torch.clamp(quant_weight, self.down, self.up)
            dequant_weight = (quant_weight - self.zero_point) * self.scale
            # apply the linear op with the dequantized weight directly,
            # instead of writing it back through the parameter's .data
            return F.linear(x, dequant_weight)
    
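    To check that the fix works, run a quick forward/backward pass with a dummy input (a minimal sketch):

    layer = QLinear(28 * 28, 100)
    x = torch.rand(4, 28 * 28)
    layer(x).sum().backward()

    print(layer.scale.grad)       # now a tensor rather than None
    print(layer.zero_point.grad)  # the zero point receives a gradient as well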

    Side notes: