How can I load checkpoints from a model trained and saved with nn.DataParallel onto a model that doesn't use nn.DataParallel? I tried to remove the 'module.' from the state_dict, but I'm encountering a different error at the moment. This is the link to the ResNet-50 checkpoints.
from collections import OrderedDict

import torch
from torchvision.models import ResNet50_Weights, resnet50

# Load the model
model = resnet50()
checkpoint_path = 'C:/res50-debiased.pth.tar'
checkpoint = torch.load(checkpoint_path)
state_dict = checkpoint['state_dict']

# The checkpoint keys start with `module.` (see the "Unexpected key(s)" list
# reported when the state_dict is loaded unchanged).
# Create a new OrderedDict whose keys do not contain that prefix.
prefix = 'module.'
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Strip the prefix only where it is actually present, so a key without
    # it is kept intact instead of losing its first seven characters.
    name = k[len(prefix):] if k.startswith(prefix) else k
    new_state_dict[name] = v

# load params
model.load_state_dict(new_state_dict)
This gives an error
RuntimeError: Error(s) in loading state_dict for ResNet:
Unexpected key(s) in state_dict: "bn1.aux_bn.weight", "bn1.aux_bn.bias", "bn1.aux_bn.running_mean", "bn1.aux_bn.running_var", "bn1.aux_bn.num_batches_tracked", "layer1.0.bn1.aux_bn.weight", "layer1.0.bn1.aux_bn.bias", "layer1.0.bn1.aux_bn.running_mean", "layer1.0.bn1.aux_bn.running_var", "layer1.0.bn1.aux_bn.num_batches_tracked", "layer1.0.bn2.aux_bn.weight", "layer1.0.bn2.aux_bn.bias", "layer1.0.bn2.aux_bn.running_mean", "layer1.0.bn2.aux_bn.running_var", "layer1.0.bn2.aux_bn.num_batches_tracked", "layer1.0.bn3.aux_bn.weight", "layer1.0.bn3.aux_bn.bias", "layer1.0.bn3.aux_bn.running_mean", "layer1.0.bn3.aux_bn.running_var", "layer1.0.bn3.aux_bn.num_batches_tracked", "layer1.0.downsample.1.aux_bn.weight", "layer1.0.downsample.1.aux_bn.bias", "layer1.0.downsample.1.aux_bn.running_mean", "layer1.0.downsample.1.aux_bn.running_var", "layer1.0.downsample.1.aux_bn.num_batches_tracked", "layer1.1.bn1.aux_bn.weight", "layer1.1.bn1.aux_bn.bias", "layer1.1.bn1.aux_bn.running_mean", "layer1.1.bn1.aux_bn.running_var", "layer1.1.bn1.aux_bn.num_batches_tracked", "layer1.1.bn2.aux_bn.weight", "layer1.1.bn2.aux_bn.bias", "layer1.1.bn2.aux_bn.running_mean", "layer1.1.bn2.aux_bn.running_var", "layer1.1.bn2.aux_bn.num_batches_tracked", "layer1.1.bn3.aux_bn.weight", "layer1.1.bn3.aux_bn.bias",
Loading normally like this
# Read the checkpoint and take its state_dict exactly as it was saved
# (no key is renamed here).
checkpoint_path = 'C:/res50-debiased.pth.tar'
state_dict = torch.load(checkpoint_path)['state_dict']
# Build a plain torchvision ResNet-50 and load the weights into it.
model = resnet50()
model.load_state_dict(state_dict)
gives the following error, whose unexpected keys all start with "module." (for example "module.conv1.weight"):
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", ...
Unexpected key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.bn1.num_batches_tracked", "module.bn1.aux_bn.weight", "module.bn1.aux_bn.bias", "module.bn1.aux_bn.running_mean", "module.bn1.aux_bn.running_var", "module.bn1.aux_bn.num_batches_tracked", "module.layer1.0.conv1.weight", "module.layer1.0.bn1.weight", "module.layer1.0.bn1.bias", "module.layer1.0.bn1.running_mean", "module.layer1.0.bn1.running_var", "module.layer1.0.bn1.num_batches_tracked", "module.layer1.0.bn1.aux_bn.weight", "module.layer1.0.bn1.aux_bn.bias", "module.layer1.0.bn1.aux_bn.running_mean", "module.layer1.0.bn1.aux_bn.running_var", "module.layer1.0.bn1.aux_bn.num_batches_tracked", "module.layer1.0.conv2.weight", "module.layer1.0.bn2.weight", "module.layer1.0.bn2.bias", "module.layer1.0.bn2.running_mean", "module.layer1.0.bn2.running_var", "module.layer1.0.bn2.num_batches_tracked", "module.layer1.0.bn2.aux_bn.weight", "module.layer1.0.bn2.aux_bn.bias", "module.layer1.0.bn2.aux_bn.running_mean", "module.layer1.0.bn2.aux_bn.running_var", "module.layer1.0.bn2.aux_bn.num_batches_tracked", "module.layer1.0.conv3.weight", "module.layer1.0.bn3.weight", "module.layer1.0.bn3.bias", "module.layer1.0.bn3.running_mean", "module.layer1.0.bn3.running_var", "module.layer1.0.bn3.num_batches_tracked", "module.layer1.0.bn3.aux_bn.weight", "module.layer1.0.bn3.aux_bn.bias", "module.layer1.0.bn3.aux_bn.running_mean", "module.layer1.0.bn3.aux_bn.running_var", "module.layer1.0.bn3.aux_bn.num_batches_tracked", "module.layer1.0.downsample.0.weight", "module.layer1.0.downsample.1.weight", "module.layer1.0.downsample.1.bias", "module.layer1.0.downsample.1.running_mean", "module.layer1.0.downsample.1.running_var", "module.
Many thanks.
You did the right thing by removing the `module.` prefix. The remaining issue comes from the fact that this `resnet50` model was initialized with a custom normalization layer, `MixBatchNorm2d`, which is defined in `aux_bn.py`. You can see the ResNet being initialized here.
This results in keys of the form `bn*.aux_bn.*` (for example `bn1.aux_bn.weight`), which a default `resnet50()` does not have, hence the "Unexpected key(s)" error.
Your code should work with the correct initialization:
# NOTE: MixBatchNorm2d is the custom normalization layer defined in the
# checkpoint authors' aux_bn.py; import it from there before running this.
checkpoint = torch.load(checkpoint_path)

# Strip the `module.` prefix only from keys that actually carry it, so a key
# without the prefix is not truncated.
prefix = 'module.'
state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in checkpoint['state_dict'].items()
}

# Build the ResNet-50 with the same norm layer the checkpointed model was
# initialized with, so the `bn*.aux_bn.*` entries have matching modules.
model = resnet50(num_classes=1_000, norm_layer=MixBatchNorm2d)
model.load_state_dict(state_dict)