To use a pretrained VisionTransformer for a downstream task, I'd like to extract its features. How can I do this with, for example, a vit_b_16 from torchvision? The output should be a 768-dimensional feature vector for each image.
As is commonly done with CNNs, I tried removing the output layer and passing the input through the remaining layers:
from torch import nn
from torchvision.models.vision_transformer import vit_b_16
from torchvision.models import ViT_B_16_Weights
from PIL import Image as PIL_Image
vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
modules = list(vit.children())[:-1]
feature_extractor = nn.Sequential(*modules)
preprocessing = ViT_B_16_Weights.DEFAULT.transforms()
img = PIL_Image.open("example.png")
img = preprocessing(img)
feature_extractor(img)
However, this raises an exception:
RuntimeError: The size of tensor a (14) must match the size of tensor b (768) at non-singleton dimension 2
Looking at the forward function in the source code of VisionTransformer and this helpful forum post, I managed to extract the features in the following way:
import torch
from torchvision.models.vision_transformer import vit_b_16
from torchvision.models import ViT_B_16_Weights
from PIL import Image as PIL_Image
vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
preprocessing = ViT_B_16_Weights.DEFAULT.transforms()
img = PIL_Image.open("example.png")
img = preprocessing(img)
# Add batch dimension
img = img.unsqueeze(0)
feats = vit._process_input(img)
# Expand the CLS token to the full batch
batch_class_token = vit.class_token.expand(img.shape[0], -1, -1)
feats = torch.cat([batch_class_token, feats], dim=1)
feats = vit.encoder(feats)
# We're only interested in the representation of the CLS token, which was prepended at position 0
feats = feats[:, 0]
print(feats.shape)
This correctly prints:
torch.Size([1, 768])
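Edit: Depending on the downstream task, it might be better to average the features of all patch tokens instead of taking the features of the CLS token. Note that this has to be applied to the full encoder output (shape [1, 197, 768]), i.e. in place of the line feats = feats[:, 0], not after it: feats_avg = feats[:, 1:].mean(dim=1)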