I want to freeze part of my model. Following the official docs:
import torch
import torch.nn as nn

with torch.no_grad():
    linear = nn.Linear(1, 1)
    linear.eval()
    print(linear.weight.requires_grad)
But it prints True instead of False. If I want to set the model in eval mode, what should I do?
If you want to freeze part of your model and train the rest, you can set requires_grad of the parameters you want to freeze to False.
For example, if you only want to keep the convolutional part of VGG16 fixed:
import torch
import torchvision

model = torchvision.models.vgg16(pretrained=True)
for param in model.features.parameters():
    param.requires_grad = False
By switching the requires_grad flags to False, no intermediate buffers will be saved until the computation reaches a point where one of the inputs of the operation requires the gradient.
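Building on the VGG16 snippet above, here is a minimal sketch of how the frozen model would then be finetuned (the choice of SGD and the learning rate are placeholder assumptions, not part of the original example): only the parameters that still require gradients are handed to the optimizer.

# Collect only the still-trainable parameters (here, the classifier part of VGG16).
trainable_params = [p for p in model.parameters() if p.requires_grad]
# The frozen convolutional weights are never touched by the update step.
optimizer = torch.optim.SGD(trainable_params, lr=1e-3)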
Using the context manager torch.no_grad is a different way to achieve that goal: in the no_grad context, all the results of the computations will have requires_grad=False, even if the inputs have requires_grad=True. Notice that you won't be able to backpropagate the gradient to layers before the no_grad. For example:
import torch
import torch.nn as nn

x = torch.randn(2, 2)
x.requires_grad = True

lin0 = nn.Linear(2, 2)
lin1 = nn.Linear(2, 2)
lin2 = nn.Linear(2, 2)

x1 = lin0(x)
with torch.no_grad():
    x2 = lin1(x1)  # computed without tracking gradients
x3 = lin2(x2)

x3.sum().backward()
print(lin0.weight.grad, lin1.weight.grad, lin2.weight.grad)
outputs:
None None tensor([[-1.4481, -1.1789],
        [-1.4481, -1.1789]])
Here lin1.weight.requires_grad was True, but the gradient wasn't computed because the operation was done in the no_grad context.
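A quick way to see this, continuing the snippet above, is to inspect the requires_grad flag of the intermediate results themselves:

print(x1.requires_grad)  # True: computed outside the no_grad context
print(x2.requires_grad)  # False: computed inside no_grad, detached from the graph

Because x2 carries no graph history, backpropagation stops there, so lin1 and lin0 never receive gradients.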
If your goal is not to finetune, but to set your model in inference mode, the most convenient way is to use the torch.no_grad context manager. In this case you also have to set your model to evaluation mode; this is achieved by calling eval() on the nn.Module, for example:
model = torchvision.models.vgg16(pretrained=True)
model.eval()
This operation sets the attribute self.training of the layers to False; in practice, this changes the behavior of operations like Dropout or BatchNorm that must behave differently at training and test time.
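As a small illustration (the dropout probability of 0.5 and the input shape are arbitrary choices), you can see the effect of eval() on a single Dropout layer:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()   # training mode: roughly half the elements are zeroed, the rest scaled by 2
print(drop(x))

drop.eval()    # evaluation mode: dropout is a no-op, the input passes through unchanged
print(drop(x))

In practice you would typically combine both mechanisms for inference: call model.eval() and run the forward pass inside a torch.no_grad() block.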