nlp, artificial-intelligence, large-language-model, image-text

How to Fine-Tune Projection Layer in CLIP Model Using LoRA?


I'm trying to fine-tune the projection layers in the CLIP model using LoRA.

I need help identifying the exact projection layers to modify for my fine-tuning and how I can apply LoRA to them.

Model loading:

import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

Model structure when printed:

CLIP(
  (visual): VisionTransformer()
  (transformer): Transformer()
  (token_embedding): Embedding(49408, 512)
  (ln_final): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)


Solution

  • You will not see the projection layers when you print the architecture with print(model), because they are initialized with nn.Parameter() in the OpenAI CLIP repo, unlike the Hugging Face implementation, which uses nn.Linear layers (you can verify this in each repo's model source).

    You can still print the parameters initialized with nn.Parameter, along with their shapes:

    for name, param in model.named_parameters():
        print(f'{name}: {param.shape}')
    

    Output:

    text_projection: torch.Size([512, 512])
    visual.proj: torch.Size([768, 512])
    ...
    
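    Those two tensors are applied as bare matrix multiplications inside the forward pass (roughly x = x @ self.proj in the OpenAI code) rather than through a module. A quick illustration, assuming the model was loaded on the CPU in float32:

    import torch

    # visual.proj maps the 768-dim ViT output into the 512-dim
    # shared embedding space via a plain matmul, not a module call
    x = torch.randn(1, 768)
    embedded = x @ model.visual.proj
    print(embedded.shape)  # torch.Size([1, 512])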

    The issue you now face is that nn.Parameter is not supported by peft/LoRA, which injects adapters into whole modules such as nn.Linear. You can either modify the CLIP code to use nn.Linear instead of nn.Parameter (a sketch of that swap follows at the end of this answer) or use the CLIP implementation of Hugging Face (mind the different layer names):

    from transformers import CLIPModel
    from peft import LoraConfig, get_peft_model
    
    transformers_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    
    # the Hugging Face CLIP exposes the projections as nn.Linear
    # modules named visual_projection and text_projection
    config = LoraConfig(
        target_modules=["visual_projection", "text_projection"],
    )
    
    peft_model = get_peft_model(transformers_model, config)
    peft_model.print_trainable_parameters()
    

    Output:

    trainable params: 18,432 || all params: 151,295,745 || trainable%: 0.0122
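
    Only the LoRA adapter weights injected into the two projection layers are trainable here; after training you can write them out with peft_model.save_pretrained(...) or fold them back into the base weights with peft_model.merge_and_unload().

    If you would rather stay with the openai clip package, the alternative is to replace the nn.Parameter projections with equivalent nn.Linear modules so that peft has a module to target. A minimal sketch, where param_to_linear is a hypothetical helper and not part of the clip package:

    import torch
    import torch.nn as nn

    def param_to_linear(proj: nn.Parameter) -> nn.Linear:
        # the projections are used as x @ proj, so proj.shape is (in, out)
        in_dim, out_dim = proj.shape
        linear = nn.Linear(in_dim, out_dim, bias=False)
        with torch.no_grad():
            linear.weight.copy_(proj.T)  # nn.Linear computes x @ weight.T
        return linear.to(device=proj.device, dtype=proj.dtype)

    model.visual.proj_linear = param_to_linear(model.visual.proj)          # [768, 512]
    model.text_projection_linear = param_to_linear(model.text_projection)  # [512, 512]

    Note that you would still have to edit the forward passes in the clip source (e.g. VisionTransformer.forward and CLIP.encode_text) so they call these modules instead of multiplying by the raw parameters, which is why the Hugging Face route above is usually less work.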