I want to implement a learned step size quantization algorithm, so I created a quantized Linear layer:
class QLinear(nn.Module):
    """Linear layer that fake-quantizes its weight on every forward pass.

    Wraps an ``nn.Linear`` and keeps two extra learnable scalars, ``scale``
    and ``zero_point``, initialised from the min/max of the wrapped layer's
    initial weight.

    Args:
        input_dim: size of each input sample, forwarded to ``nn.Linear``.
        out_dim: size of each output sample, forwarded to ``nn.Linear``.
        bits: bit width; the integer grid is ``[0, 2 ** bits - 1]``.
    """

    def __init__(self, input_dim, out_dim, bits=8):
        super().__init__()
        # Bounds of the integer grid the weight is mapped onto.
        self.up = 2 ** bits - 1
        self.down = 0
        self.fc = nn.Linear(input_dim, out_dim)

        # Affine quantization parameters derived from the initial weight range.
        initial_weight = self.fc.weight.data
        w_min = torch.min(initial_weight)
        w_max = torch.max(initial_weight)
        levels = self.up - self.down
        self.scale = nn.Parameter(torch.Tensor((w_max - w_min) / levels), requires_grad=True)
        offset = self.down - (w_min / self.scale).round()
        self.zero_point = nn.Parameter(torch.Tensor(offset), requires_grad=True)

    def forward(self, x):
        """Quantize, clamp and dequantize the weight, then apply the wrapped layer to ``x``."""
        w = self.fc.weight
        # round_ste is defined outside this block.
        q = round_ste(w / self.scale) + self.zero_point
        q = torch.clamp(q, self.down, self.up)
        restored = (q - self.zero_point) * self.scale
        # The dequantized tensor overwrites the stored weight through ``.data``
        # and the wrapped layer is then called as usual.
        # NOTE(review): a write through ``.data`` is not recorded by autograd,
        # so this path carries no gradient back to ``scale`` / ``zero_point``.
        self.fc.weight.data = restored
        return self.fc(x)
class QNet(nn.Module):
    """Two-layer classifier for 28x28 inputs built from ``QLinear`` layers.

    The network flattens each sample to 784 features, applies
    ``QLinear(784, 100)`` + ReLU, then ``QLinear(100, 10)``, and returns
    softmax probabilities over the 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = QLinear(28 * 28, 100)
        self.fc2 = QLinear(100, 10)

    def forward(self, x):
        """Return per-sample softmax probabilities of shape ``(N, 10)``."""
        # Collapse everything but the batch axis into 784 features.
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # x is 2-D after the view above, so the class axis is dim 1. Passing
        # it explicitly avoids the deprecated implicit-dim behaviour of
        # F.softmax while producing the same values.
        return F.softmax(x, dim=1)
When I train this network, scale's grad is always None. Why does this happen, and how can I solve it?
The issue is that you are passing dequant_weight through the data attribute of your parameter, so the operation is never registered by autograd. A simple alternative is to handle weight as an nn.Parameter and apply the linear operator manually in the forward definition, directly with the computed weight dequant_weight.
Here is a minimal example that should work:
class QLinear(nn.Module):
    """Bias-free linear layer with learnable affine fake-quantization of its weight.

    On every forward pass the weight is mapped onto the integer grid
    ``[0, 2 ** bits - 1]`` using the learnable ``scale`` and ``zero_point``,
    clamped, mapped back to floating point, and used in ``F.linear``. The
    whole chain stays inside the autograd graph, so ``weight``, ``scale`` and
    ``zero_point`` all receive gradients.

    Args:
        input_dim: number of input features.
        out_dim: number of output features.
        bits: bit width of the integer grid.
    """

    def __init__(self, input_dim, out_dim, bits=8):
        super().__init__()
        self.up = 2 ** bits - 1
        self.down = 0
        self.weight = nn.Parameter(torch.rand(out_dim, input_dim))

        # These are initial values only, so compute them outside autograd and
        # hand plain tensors to nn.Parameter.
        with torch.no_grad():
            w_min = self.weight.min()
            w_max = self.weight.max()
            scale = (w_max - w_min) / (self.up - self.down)
            # A constant weight (e.g. a 1x1 layer) would give scale == 0 and
            # NaNs in forward; keep the step strictly positive.
            scale = scale.clamp_min(torch.finfo(scale.dtype).eps)
            zero_point = self.down - torch.round(w_min / scale)
        self.scale = nn.Parameter(scale)
        self.zero_point = nn.Parameter(zero_point)

    def forward(self, x):
        """Apply the linear map to ``x`` using the fake-quantized weight."""
        scaled = self.weight / self.scale
        # Straight-through estimator: the forward value equals
        # torch.round(scaled), but the backward pass treats rounding as the
        # identity. Plain torch.round has zero gradient, which would leave
        # ``weight`` with an all-zero gradient.
        rounded = scaled + (torch.round(scaled) - scaled).detach()
        quant_weight = torch.clamp(rounded + self.zero_point, self.down, self.up)
        dequant_weight = (quant_weight - self.zero_point) * self.scale
        return F.linear(x, dequant_weight)
Side notes:
nn.Parameter requires gradient computation by default (no need to provide requires_grad=True).
Additionally, you can reformat QNet by inheriting from nn.Sequential to avoid boilerplate code:
class QNet(nn.Sequential):
    """Two-layer classifier for 28x28 inputs, expressed as an ``nn.Sequential``.

    Pipeline: flatten -> ``QLinear(784, 100)`` -> ReLU -> ``QLinear(100, 10)``
    -> softmax over the 10 outputs.
    """

    def __init__(self):
        super().__init__(
            nn.Flatten(),
            QLinear(28 * 28, 100),
            nn.ReLU(),
            QLinear(100, 10),
            # nn.Flatten leaves a 2-D (batch, features) tensor, so the class
            # axis is dim 1. Naming it explicitly avoids the deprecated
            # implicit-dim behaviour of nn.Softmax().
            nn.Softmax(dim=1),
        )