pytorch, computer-vision, feature-extraction, transformer-model, torchvision

How do I extract features from a torchvision VisionTransformer (ViT)?


I'd like to extract features from a pretrained VisionTransformer to use them for a downstream task. How do I do this, for example with a vit_b_16 from torchvision? The output should be a 768-dimensional feature vector for each image.

Similar to what is done with CNNs, I tried simply removing the output layer and passing the input through the remaining layers:

    from torch import nn

    from torchvision.models.vision_transformer import vit_b_16
    from torchvision.models import ViT_B_16_Weights
    
    from PIL import Image as PIL_Image

    vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
    # Drop the last child module (the classification head), as one would with a CNN
    modules = list(vit.children())[:-1]
    feature_extractor = nn.Sequential(*modules)

    preprocessing = ViT_B_16_Weights.DEFAULT.transforms()

    img = PIL_Image.open("example.png")
    img = preprocessing(img)

    feature_extractor(img)

However, this raises an exception:

    RuntimeError: The size of tensor a (14) must match the size of tensor b (768) at non-singleton dimension 2

Solution

  • Looking at the forward function in the source code of VisionTransformer and this helpful forum post, I managed to extract the features as follows. The nn.Sequential approach fails because it skips the logic in forward() that reshapes the patch embeddings into a token sequence and prepends the class token, so those steps have to be called explicitly:

    
        import torch
    
        from torchvision.models.vision_transformer import vit_b_16
        from torchvision.models import ViT_B_16_Weights
    
        from PIL import Image as PIL_Image
    
        vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
    
        preprocessing = ViT_B_16_Weights.DEFAULT.transforms()
    
        img = PIL_Image.open("example.png")
        img = preprocessing(img)
    
        # Add batch dimension
        img = img.unsqueeze(0)
    
        # Patchify the image and project the patches to the 768-dim embedding space
        feats = vit._process_input(img)
    
        # Expand the CLS token to the full batch
        batch_class_token = vit.class_token.expand(img.shape[0], -1, -1)
        feats = torch.cat([batch_class_token, feats], dim=1)
    
        # Run the transformer encoder (positional embeddings are added inside)
        feats = vit.encoder(feats)
    
        # We're only interested in the representation of the CLS token that we prepended at position 0
        feats = feats[:, 0]
    
        print(feats.shape)
    
    

    Which correctly returns:

    torch.Size([1, 768])
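
    In practice, for a downstream task you would run this without gradients and with the model in eval mode. A minimal sketch that wraps the steps above into a reusable helper, continuing from the code above (the function name extract_cls_features is mine, not part of torchvision):

        # Reuse the vit model and the preprocessed, batched img from above
        vit.eval()  # disable dropout for deterministic features

        @torch.no_grad()  # no gradients needed for pure feature extraction
        def extract_cls_features(model, batch):
            # batch: preprocessed float tensor of shape (N, 3, 224, 224)
            feats = model._process_input(batch)
            class_token = model.class_token.expand(batch.shape[0], -1, -1)
            feats = torch.cat([class_token, feats], dim=1)
            feats = model.encoder(feats)
            return feats[:, 0]  # (N, 768) CLS representations

        print(extract_cls_features(vit, img).shape)  # torch.Size([1, 768])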
    

    Edit: Depending on the downstream task, it might be better to average the features of all patch tokens instead of taking the features of the CLS token: feats_avg = feats[:, 1:].mean(dim=1). Note that this has to be computed on the full encoder output, i.e. before the patch tokens are discarded by feats = feats[:, 0] above.
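
    A minimal sketch of that variant, reusing vit and the preprocessed img from the code above (tokens, feats_cls and feats_avg are my own variable names):

        # Recompute the full token sequence: (1, 197, 768) = 1 CLS token + 14*14 patch tokens
        tokens = vit._process_input(img)
        tokens = torch.cat([vit.class_token.expand(img.shape[0], -1, -1), tokens], dim=1)
        tokens = vit.encoder(tokens)

        feats_cls = tokens[:, 0]               # (1, 768) CLS representation
        feats_avg = tokens[:, 1:].mean(dim=1)  # (1, 768) mean over the 196 patch tokens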