
Error loading state_dict for ViT-H-14 model in PyTorch


I'm trying to train a Vision Transformer (ViT-H-14) model using a pre-trained weight file, but I'm encountering an error when loading the state_dict. The error occurs when I load the weights manually using the following code:

import torch
import torchvision

def vit_h_14(weight_path="/content/vit_h_14_swag-80465313.pth"):
    pretrained_vit = torchvision.models.vit_h_14()
    pretrained_vit.load_state_dict(torch.load(weight_path))

    for parameter in pretrained_vit.parameters():
        parameter.requires_grad = False
    return pretrained_vit

The error message I get is:

RuntimeError: Error(s) in loading state_dict for VisionTransformer:
    size mismatch for encoder.pos_embedding: copying a param with shape torch.Size([1, 1370, 1280]) from checkpoint, the shape in current model is torch.Size([1, 257, 1280]).

However, when I load the pre-trained weights directly from the PyTorch API using the following code, the model trains successfully:

def vit_h_14():
    pretrained_vit_weights = torchvision.models.ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1
    pretrained_vit = torchvision.models.vit_h_14(weights=pretrained_vit_weights).to(device)

    for parameter in pretrained_vit.parameters():
        parameter.requires_grad = False
    return pretrained_vit

I'm puzzled by this difference because other pre-trained ViT models can be trained using the same approach without any issues. Here's my training code for reference:

import torch
import torchvision
from torch import nn

pretrained_transform = torchvision.models.ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1.transforms()

train_dataloader, val_dataloader, test_dataloader, class_names = create_dataloaders(
    train_dir=train_dir,
    val_dir=val_dir,
    transforms=pretrained_transform,
    test_dir=test_dir,
    batch_size=BATCH_SIZE
)

# set up model
model = vit_h_14(weight_path=weight_path)
model.heads = nn.Linear(
    in_features=model.heads.head.in_features,
    out_features=len(class_names)
).to(device)

# set up optimizer and loss function
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.CrossEntropyLoss()

# train model
results = train(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=EPOCH,
    device=device
)

# save model
model_path = f"../results/{arch_name}/model.pth"
torch.save(model.state_dict(), model_path)

Can someone help me understand why I'm getting this error when loading the weights manually for the ViT-H-14 model, and what I can do to resolve it? I'd appreciate any insights or suggestions. Thank you in advance!


Solution

  • The short answer: In your vit_h_14(weight_path) function, replace pretrained_vit = torchvision.models.vit_h_14() with pretrained_vit = torchvision.models.vit_h_14(image_size=518).

    The long answer

    When you load weights via the constant IMAGENET1K_SWAG_E2E_V1, as you do in your second implementation, torchvision does more than just load the weights: it also configures the model to match them, which your manual version skips. In particular:

    1. IMAGENET1K_SWAG_E2E_V1 is a member of the ViT_H_14_Weights enum (a subclass of WeightsEnum). Its value is a torchvision.models._api.Weights instance that bundles the checkpoint URL, the matching preprocessing transforms and a "meta" dictionary (see corresponding source code, line 562).
    2. When you call torchvision.models.vit_h_14(weights=...), it internally calls the function _vision_transformer() and passes it, among other arguments, this weights enum member (see corresponding source code, line 777).
    3. _vision_transformer() then reads additional configuration from the weights. Crucially, it takes weights.meta["min_size"][0] and passes it as "image_size" when creating the actual VisionTransformer instance (see corresponding source code, lines 321 and 331). For IMAGENET1K_SWAG_E2E_V1, this "min_size"/"image_size" value is 518 (see corresponding source code, line 574).

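    It is this "min_size"/"image_size" parameter that prevents you from loading the state dict manually. Your manual version never sets image_size, so VisionTransformer falls back to its default of 224, and encoder.pos_embedding is sized for 224x224 inputs rather than the 518x518 inputs the checkpoint was trained on. That is the size mismatch reported in the error.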
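
    The two numbers in the error message follow from this. VisionTransformer splits each image into (image_size // patch_size) ** 2 patches and prepends one class token, and encoder.pos_embedding has one row per resulting token. A small sketch of that arithmetic (pos_embedding_length is a hypothetical helper, not part of torchvision):

```python
def pos_embedding_length(image_size: int, patch_size: int = 14) -> int:
    """Rows of encoder.pos_embedding for a square-input ViT (hypothetical helper)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    # One embedding per patch, plus one for the prepended class token.
    return patches_per_side ** 2 + 1


default_len = pos_embedding_length(224)  # vit_h_14() without image_size
swag_len = pos_embedding_length(518)     # resolution of the SWAG E2E checkpoint
print(default_len, swag_len)  # 257 1370
```

    These match the shapes torch.Size([1, 257, 1280]) (current model) and torch.Size([1, 1370, 1280]) (checkpoint) from the error; 1280 is ViT-H-14's hidden dimension.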

    The following code fixes the problem, so that the weights can be loaded:

    import torch
    import torchvision
    
    weight_path = "/your/path/to/vit_h_14_swag-80465313.pth"  # TODO: adjust
    
    def vit_h_14_manual(weight_path):
        # image_size must match the resolution the SWAG E2E checkpoint was trained at (518x518)
        pretrained_vit = torchvision.models.vit_h_14(image_size=518)
        pretrained_vit.load_state_dict(torch.load(weight_path))
    
        for parameter in pretrained_vit.parameters():
            parameter.requires_grad = False
        return pretrained_vit
    
    vit_h_14_manual(weight_path)
    

    I am not sure whether this is the only adjustment your code needs to work as expected. For example, the "num_classes" parameter is overwritten in the same way as "image_size" (see corresponding source code, line 319), so proceed with caution.