I'm trying to train a Vision Transformer (ViT-H-14) model using a pre-trained weight file, but I'm encountering an error when loading the state_dict. The error occurs when I load the weights manually using the following code:
def vit_h_14(weight_path="/content/vit_h_14_swag-80465313.pth"):
    pretrained_vit = torchvision.models.vit_h_14()
    pretrained_vit.load_state_dict(torch.load(weight_path))
    for parameter in pretrained_vit.parameters():
        parameter.requires_grad = False
    return pretrained_vit
The error message I get is:
RuntimeError: Error(s) in loading state_dict for VisionTransformer:
size mismatch for encoder.pos_embedding: copying a param with shape torch.Size([1, 1370, 1280]) from checkpoint, the shape in current model is torch.Size([1, 257, 1280]).
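For reference, comparing the checkpoint's positional embedding against a freshly constructed model makes the mismatch explicit (a quick diagnostic sketch; the shapes in the comments are the ones reported in the error above):

import torch
import torchvision

# Load the raw checkpoint and build a default ViT-H/14 for comparison.
state_dict = torch.load("/content/vit_h_14_swag-80465313.pth")
model = torchvision.models.vit_h_14()

# The checkpoint carries far more patch tokens than the default model expects.
print(state_dict["encoder.pos_embedding"].shape)          # torch.Size([1, 1370, 1280])
print(model.state_dict()["encoder.pos_embedding"].shape)  # torch.Size([1, 257, 1280])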
However, when I load the pre-trained weights directly from the PyTorch API using the following code, the model trains successfully:
def vit_h_14():
    pretrained_vit_weights = torchvision.models.ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1
    pretrained_vit = torchvision.models.vit_h_14(weights=pretrained_vit_weights).to(device)
    for parameter in pretrained_vit.parameters():
        parameter.requires_grad = False
    return pretrained_vit
I'm puzzled by this difference because other pre-trained ViT models can be trained using the same approach without any issues. Here's my training code for reference:
pretrained_transform = torchvision.models.ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1.transforms()
train_dataloader, val_dataloader, test_dataloader, class_names = create_dataloaders(
    train_dir=train_dir,
    val_dir=val_dir,
    transforms=pretrained_transform,
    test_dir=test_dir,
    batch_size=BATCH_SIZE
)

# setup model
model = vit_h_14(weight_path=weight_path)
model.heads = nn.Linear(
    in_features=model.heads.head.in_features,
    out_features=len(class_names)
).to(device)

# setup optimizer and loss function
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.CrossEntropyLoss()

# train model
results = train(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=EPOCH,
    device=device
)

# save model
model_path = f"../results/{arch_name}/model.pth"
torch.save(model.state_dict(), model_path)
Can someone help me understand why I'm getting this error when loading the weights manually for the ViT-H-14 model, and what I can do to resolve it? I'd appreciate any insights or suggestions. Thank you in advance!
The short answer: replace pretrained_vit = torchvision.models.vit_h_14() with pretrained_vit = torchvision.models.vit_h_14(image_size=518) in your vit_h_14(weight_path) function.
When you load weights via the constant IMAGENET1K_SWAG_E2E_V1, as you do in your second implementation, some additional things happen compared to the version where you try to load the weights directly/manually; in particular:

- With IMAGENET1K_SWAG_E2E_V1, the weights are wrapped into a ViT_H_14_Weights(WeightsEnum) instance, which holds torchvision.models._api.Weights instances as values (see corresponding source code, line 562).
- When calling torchvision.models.vit_h_14(), internally, the function _vision_transformer() is called and receives, among others, the ViT_H_14_Weights(WeightsEnum) instance (see corresponding source code, line 777).
- _vision_transformer() extracts additional parameters from the ViT_H_14_Weights(WeightsEnum) instance. Crucially, it extracts the attribute weights.meta["min_size"][0], which is then passed as "image_size" when creating the actual VisionTransformer instance (see corresponding source code, lines 321 and 331). This "min_size"/"image_size" value is 518 (see corresponding source code, line 574).

It is this "min_size"/"image_size" parameter that prevents you from loading the state_dict manually/directly, as it is not set in your manual version.
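You can inspect this value directly on the weights enum yourself (a quick check, assuming a torchvision version that ships the SWAG weights):

import torchvision

weights = torchvision.models.ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1
print(weights.meta["min_size"][0])  # 518, the value that ends up as image_size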
The following code fixes the problem, so that the weights can be loaded:
import torch
import torchvision
weight_path = "/your/path/to/vit_h_14_swag-80465313.pth" # TODO: adjust
def vit_h_14_manual(weight_path):
    pretrained_vit = torchvision.models.vit_h_14(image_size=518)
    pretrained_vit.load_state_dict(torch.load(weight_path))
    for parameter in pretrained_vit.parameters():
        parameter.requires_grad = False
    return pretrained_vit

vit_h_14_manual(weight_path)
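If loading succeeds, the positional embedding of the returned model should now have the same shape as the one in the checkpoint from your original error (a small sanity check, using the same weight_path as above):

model = vit_h_14_manual(weight_path)
print(model.encoder.pos_embedding.shape)  # torch.Size([1, 1370, 1280]), matching the checkpoint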
I am not sure if this is the only adjustment you need to make for your code to work as expected. For example, the "num_classes" parameter is overwritten in a similar fashion to the "image_size" parameter (see corresponding source code, line 319), so proceed with caution.
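On the "num_classes" point specifically, one way to check whether it matters for this particular checkpoint is to compare the checkpoint metadata with the model's default classification head (a small sketch; I believe both report 1000 here, but verify on your torchvision version):

import torchvision

weights = torchvision.models.ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1
model = torchvision.models.vit_h_14(image_size=518)

# If these two numbers agree, the head weights in the checkpoint line up with
# the model's default head and load_state_dict should not complain about them.
print(weights.meta["num_classes"])    # 1000 for this ImageNet-1k checkpoint
print(model.heads.head.out_features)  # 1000 by default as well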