I have defined the model as in the code below, and I used batch-normalization merging (folding) to combine its 3 layers into a single linear layer.
The variables `new_weight` and `new_bias` are the weight and bias of the newly created linear layer, respectively.
My question is: why do the two `print` calls at the end produce different outputs? And what is wrong in the code below the `# batch merge` comment?
```python
import torch
import torch.nn as nn
import torch.optim as optim

learning_rate = 0.01
in_nodes = 20
internal_nodes = 8
out_nodes = 9
batch_size = 100

# model define
class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.layer1 = nn.Linear(in_nodes, internal_nodes, bias=False)
        self.layer2 = nn.BatchNorm1d(internal_nodes, affine=False)
        self.layer3 = nn.Linear(internal_nodes, out_nodes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

# optimizer and criterion
model = M()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# training
for batch_num in range(1000):
    model.train()
    optimizer.zero_grad()
    input = torch.randn(batch_size, in_nodes)
    target = torch.ones(batch_size, out_nodes)
    output = model(input)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

# batch merge
divider = torch.sqrt(model.layer2.eps + model.layer2.running_var)
w_bn = torch.diag(torch.ones(internal_nodes) / divider)
new_weight = torch.mm(w_bn, model.layer1.weight)
new_weight = torch.mm(model.layer3.weight, new_weight)
b_bn = - model.layer2.running_mean / divider
new_bias = model.layer3.bias + torch.squeeze(torch.mm(model.layer3.weight, b_bn.reshape(-1, 1)))
input = torch.randn(batch_size, in_nodes)
print(model(input))
print(torch.t(torch.mm(new_weight, torch.t(input))) + new_bias)
```
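For reference, the code under `# batch merge` encodes the algebra below. Here $W_1$ is `model.layer1.weight`, $W_3$ and $b_3$ are `model.layer3.weight` and `model.layer3.bias`, $\mu$ and $\sigma^2$ are `running_mean` and `running_var`, $\epsilon$ is `eps`, $D$ corresponds to `w_bn`, and $-D\mu$ to `b_bn`. Note that the identity only holds when the batch-norm layer normalizes with $\mu$ and $\sigma^2$, i.e. with its running statistics.

```latex
D = \operatorname{diag}\!\left(\frac{1}{\sqrt{\sigma^2 + \epsilon}}\right), \qquad
y = W_3 \, D \, (W_1 x - \mu) + b_3
  = \underbrace{W_3 D W_1}_{\texttt{new\_weight}} \, x
  + \underbrace{b_3 - W_3 D \mu}_{\texttt{new\_bias}}
```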
Short answer: as far as I can tell, you need to call `model.eval()` before the line `input = torch.randn(batch_size, in_nodes)`, so that the end of the script looks like this:
```python
...
model.eval()
input = torch.randn(batch_size, in_nodes)
print(model(input))
print(torch.t(torch.mm(new_weight, torch.t(input))) + new_bias)
```
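With that change (I tested it), the two `print` statements output the same values, up to floating-point rounding. `model.eval()` does not change any weights; it makes the batch-norm layer use the same running statistics that `new_weight` and `new_bias` were built from.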
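A minimal, self-contained sketch of the same check, assuming the question's layer sizes (20 → 8 → 9) and skipping the optimizer, since only the batch-norm running statistics matter here. `fold` is a hypothetical helper name for illustration, not a PyTorch API. The folded layer matches the model in evaluation mode but not in training mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same structure as the question: Linear (no bias) -> BatchNorm1d(affine=False) -> Linear
layer1 = nn.Linear(20, 8, bias=False)
bn = nn.BatchNorm1d(8, affine=False)
layer3 = nn.Linear(8, 9)
model = nn.Sequential(layer1, bn, layer3)

# A few training-mode forward passes so running_mean / running_var move away from 0 / 1
model.train()
with torch.no_grad():
    for _ in range(10):
        model(torch.randn(100, 20))

def fold(layer1, bn, layer3):
    """Hypothetical helper: fold Linear -> BatchNorm1d(affine=False) -> Linear into one Linear."""
    with torch.no_grad():
        inv_std = 1.0 / torch.sqrt(bn.running_var + bn.eps)              # diagonal of D
        weight = layer3.weight @ (inv_std[:, None] * layer1.weight)      # W3 D W1
        bias = layer3.bias - layer3.weight @ (inv_std * bn.running_mean) # b3 - W3 D mu
        folded = nn.Linear(layer1.in_features, layer3.out_features)
        folded.weight.copy_(weight)
        folded.bias.copy_(bias)
    return folded

folded = fold(layer1, bn, layer3)
x = torch.randn(100, 20)

model.eval()  # BatchNorm normalizes with running statistics, like the folded layer
with torch.no_grad():
    eval_match = torch.allclose(model(x), folded(x), atol=1e-5)

model.train()  # BatchNorm normalizes with the statistics of x itself
with torch.no_grad():
    train_match = torch.allclose(model(x), folded(x), atol=1e-5)

print(eval_match, train_match)  # True False
```

Folding into an ordinary `nn.Linear` keeps the result usable as a drop-in module; the helper only reads `running_mean`, `running_var`, and `eps`, which is why it agrees with the model only in evaluation mode.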
Long answer:
According to the PyTorch documentation, batch normalization uses a default `momentum` of 0.1 to update `running_mean` and `running_var`: after each training-mode forward pass, `running = (1 - momentum) * running + momentum * batch_statistic`. The momentum therefore determines how much the previous estimate and how much the newly observed batch statistic contribute to the updated value.
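To see the momentum update concretely, here is a minimal sketch (toy sizes chosen arbitrarily) that reproduces one update of the running statistics by hand. It assumes the default `momentum=0.1` and relies on the documented detail that `running_var` is updated with the unbiased batch variance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4, affine=False)  # momentum defaults to 0.1
x = torch.randn(32, 4)

old_mean = bn.running_mean.clone()  # initialized to zeros
old_var = bn.running_var.clone()    # initialized to ones

bn.train()
with torch.no_grad():
    bn(x)  # a training-mode forward pass updates the running statistics

m = bn.momentum
expected_mean = (1 - m) * old_mean + m * x.mean(dim=0)
# Tensor.var defaults to the unbiased estimator (divides by N - 1)
expected_var = (1 - m) * old_var + m * x.var(dim=0)

print(torch.allclose(bn.running_mean, expected_mean, atol=1e-6))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-6))    # True
```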
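If you don't call `model.eval()`, the model is still in training mode when `print(model(input))` runs. In training mode the batch-norm layer normalizes with the mean and variance of the current `input` batch instead of `running_mean` and `running_var`, and it also updates `running_mean` and `running_var` using the momentum. `new_weight` and `new_bias`, however, were computed from the running statistics, so the two outputs cannot match. After `model.eval()`, the layer normalizes with `running_mean` and `running_var` and leaves them untouched, which is exactly what the merged layer assumes.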
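The following minimal sketch (toy sizes chosen arbitrarily) compares what `BatchNorm1d` computes in training mode and in evaluation mode. It shows that training mode normalizes with the current batch's mean and biased variance, while evaluation mode uses `running_mean` and `running_var` and does not update them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4, affine=False)
# A batch whose statistics are far from the initial running stats (mean 0, var 1)
x = 3.0 + 2.0 * torch.randn(64, 4)

# Training mode: normalize with this batch's mean and biased (divide-by-N) variance
bn.train()
with torch.no_grad():
    y_train = bn(x)
mean = x.mean(dim=0)
biased_var = ((x - mean) ** 2).mean(dim=0)
expected_train = (x - mean) / torch.sqrt(biased_var + bn.eps)

# Evaluation mode: normalize with running_mean / running_var, without updating them
bn.eval()
running_before = bn.running_mean.clone()
with torch.no_grad():
    y_eval = bn(x)
expected_eval = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)

print(torch.allclose(y_train, expected_train, atol=1e-5))  # True
print(torch.allclose(y_eval, expected_eval, atol=1e-5))    # True
print(torch.equal(running_before, bn.running_mean))        # True
print(torch.allclose(y_train, y_eval))                     # False
```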
For further details and/or confirmation: Related Question, PyTorch Documentation.