Tags: python, pytorch, batch-normalization

How to implement batch normalization merging in python?


I have defined the model shown in the code below, and I used batch normalization merging (folding) to combine its three layers into a single linear layer.

The variables named new_weight and new_bias are the weight and bias of the newly created linear layer, respectively.

My question is: why do the two print calls at the end produce different outputs? And which part of the code below the # batch merge comment is wrong?

import torch
import torch.nn as nn
import torch.optim as optim

learning_rate = 0.01
in_nodes = 20
internal_nodes = 8
out_nodes = 9
batch_size = 100

# model define
class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()

        self.layer1 = nn.Linear(in_nodes, internal_nodes, bias=False)
        self.layer2 = nn.BatchNorm1d(internal_nodes, affine=False)
        self.layer3 = nn.Linear(internal_nodes, out_nodes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


# optimizer and criterion
model = M()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()


# training
for batch_num in range(1000):
    model.train()
    optimizer.zero_grad()

    input = torch.randn(batch_size, in_nodes)
    target = torch.ones(batch_size, out_nodes)
    
    output = model(input)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()


# batch merge
divider = torch.sqrt(model.layer2.eps + model.layer2.running_var)

w_bn = torch.diag(torch.ones(internal_nodes) / divider)
new_weight = torch.mm(w_bn, model.layer1.weight)
new_weight = torch.mm(model.layer3.weight, new_weight)

b_bn = - model.layer2.running_mean / divider
new_bias = model.layer3.bias + torch.squeeze(torch.mm(model.layer3.weight, b_bn.reshape(-1, 1)))



input = torch.randn(batch_size, in_nodes)
print(model(input))
print(torch.t(torch.mm(new_weight, torch.t(input))) + new_bias)
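For reference, this is the algebra that the code under # batch merge implements. The notation is mine: W_1 is layer1.weight, W_3 and b_3 are layer3.weight and layer3.bias, μ and σ² are running_mean and running_var, ε is eps, and the square root and division inside D are element-wise. Because the BatchNorm layer uses affine=False, there is no scale or shift term to fold in.

```latex
\begin{aligned}
y &= W_3 D \left( W_1 x - \mu \right) + b_3,
   \qquad D = \operatorname{diag}\!\left( \frac{1}{\sqrt{\sigma^2 + \epsilon}} \right) \\
  &= \underbrace{W_3 D W_1}_{\texttt{new\_weight}} \, x
   + \underbrace{b_3 - W_3 D \mu}_{\texttt{new\_bias}}
\end{aligned}
```

The identity only holds when BatchNorm normalizes with μ and σ², i.e. with the running statistics, which PyTorch uses only in evaluation mode.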

Solution

  • Short Answer: As far as I can tell, you need to call model.eval() before the line

    input = torch.randn(batch_size, in_nodes)
    

    so that the end of the script looks like this:

    ...
    model.eval()  # BatchNorm now uses (and no longer updates) running_mean / running_var
    input = torch.randn(batch_size, in_nodes)
    print(model(input))
    print(torch.t(torch.mm(new_weight, torch.t(input))) + new_bias)
    

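    With that change (I tested it), the two print statements output the same values, up to floating-point rounding. model.eval() switches the BatchNorm layer to its stored running statistics, which are exactly the values new_weight and new_bias were built from.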

    Long Answer:

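    According to the PyTorch documentation, BatchNorm keeps running_mean and running_var as running estimates with a default momentum of 0.1: each training-mode forward pass updates them as running = (1 - momentum) * running + momentum * batch_statistic, so the momentum sets how much weight the new batch gets relative to the existing estimate.

    Without model.eval(), the model is still in training mode (model.train() was called inside the loop), so in the line

    print(model(input))

    the BatchNorm layer normalizes with the mean and variance of the current batch rather than with running_mean and running_var, and it also updates the running statistics once more. The merged new_weight and new_bias were computed from running_mean and running_var, so they can only reproduce the model's output in evaluation mode.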
    

    For further details and/or confirmation, see the related question and the PyTorch documentation.
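A minimal, self-contained sketch that checks both parts of the answer numerically. Assumptions: it rebuilds the question's three layers with nn.Sequential instead of class M, it skips the optimizer (train-mode forward passes under torch.no_grad() are enough to move the running statistics), and fold is a hypothetical helper written for this sketch, not a PyTorch API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_nodes, internal_nodes, out_nodes, batch_size = 20, 8, 9, 100

# Same layers as the question's class M: Linear (no bias) -> BatchNorm1d (affine=False) -> Linear
model = nn.Sequential(
    nn.Linear(in_nodes, internal_nodes, bias=False),
    nn.BatchNorm1d(internal_nodes, affine=False),
    nn.Linear(internal_nodes, out_nodes),
)
lin1, bn, lin3 = model[0], model[1], model[2]

# Train-mode forward passes move running_mean / running_var away from their defaults (0 and 1)
model.train()
with torch.no_grad():
    for _ in range(50):
        model(torch.randn(batch_size, in_nodes))


def fold(lin1, bn, lin3):
    """Hypothetical helper: fold Linear -> BatchNorm1d(affine=False) -> Linear
    into one (weight, bias) pair using the BatchNorm running statistics."""
    inv_std = 1.0 / torch.sqrt(bn.running_var + bn.eps)
    weight = lin3.weight @ (inv_std[:, None] * lin1.weight)
    bias = lin3.bias - lin3.weight @ (bn.running_mean * inv_std)
    return weight, bias


with torch.no_grad():
    new_weight, new_bias = fold(lin1, bn, lin3)
    x = torch.randn(batch_size, in_nodes)
    folded_out = x @ new_weight.T + new_bias

    # 1) Eval mode: BatchNorm uses the running statistics, so the folded layer matches
    model.eval()
    eval_matches = torch.allclose(model(x), folded_out, atol=1e-4)

    # 2) Train mode: BatchNorm normalizes with this batch's mean and biased variance instead
    model.train()
    old_mean = bn.running_mean.clone()
    old_var = bn.running_var.clone()
    train_out = model(x)
    train_matches = torch.allclose(train_out, folded_out, atol=1e-4)

    h = lin1(x)
    batch_mean = h.mean(dim=0)
    by_hand = lin3((h - batch_mean) / torch.sqrt(h.var(dim=0, unbiased=False) + bn.eps))
    batch_stats_match = torch.allclose(train_out, by_hand, atol=1e-4)

    # 3) The same train-mode call also updated the running statistics:
    #    running = (1 - momentum) * running + momentum * batch_statistic (unbiased variance)
    m = bn.momentum
    mean_updated = torch.allclose(bn.running_mean, (1 - m) * old_mean + m * batch_mean, atol=1e-5)
    var_updated = torch.allclose(
        bn.running_var, (1 - m) * old_var + m * h.var(dim=0, unbiased=True), atol=1e-5
    )

print("eval mode matches folded layer:     ", eval_matches)
print("train mode matches folded layer:    ", train_matches)
print("train mode uses batch statistics:   ", batch_stats_match)
print("running stats updated with momentum:", mean_updated and var_updated)
```

The first, third and fourth checks should print True and the second False: the folded layer agrees with the model only in eval mode, and a train-mode call both uses batch statistics and shifts the running statistics.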