The problem
The similarity scores are almost the same for a text describing a photo of a cat and a text describing a photo of a dog, even though the photo is of a cat.
Cat similarity: tensor([[-3.5724]], grad_fn=<MulBackward0>)
Dog similarity: tensor([[-3.4155]], grad_fn=<MulBackward0>)
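For context, passing these two logits through a softmax (a quick check, not part of the model code) shows that the model even leans slightly towards the wrong prompt:

import torch

# Quick check (not part of the model code): turn the two logits into probabilities.
logits = torch.tensor([-3.5724, -3.4155])   # cat, dog
print(logits.softmax(dim=-1))               # roughly [0.46, 0.54] -> the dog prompt wins slightly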
The code for the CLIP model
The code is based on the openai/clip-vit-base-patch32 checkpoint. The encode_text function takes raw input and turns it into embeddings that are later fed into the forward method. I'm certain that the layers' names and sizes are correct, because the checkpoint loads into the model without errors about missing or unexpected layers (a sanity-check sketch follows the class below).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Resize

class CLIP(nn.Module):
def __init__(self, project_dim: int = 768, embed_dim: int = 512):
super(CLIP, self).__init__()
self.vision_model = ImageEncoder(project_dim = project_dim)
self.text_model = TextEncoder(embed_dim = embed_dim)
self.tokenizer = TorchTokenizer()
self.logit_scale = nn.Parameter(torch.ones([]) * 0.7)
self.visual_projection = nn.Linear(project_dim, embed_dim, bias = False)
self.text_projection = nn.Linear(embed_dim, embed_dim, bias = False)
self.vision_model.eval()
self.text_model.eval()
def forward(self, image: torch.Tensor, text_embed: torch.Tensor) -> torch.Tensor:
" Compute the relationship between image and text "
# get fixed size to comply with the checkpoint position_embeddings nn.Embedding(50, embed_dim)
image = Resize(size=(224, 224))(image)
image_features = self.vision_model(image)
# projections
text_features = self.text_projection(text_embed)
image_features = self.visual_projection(image_features)
# normalization
text_features = F.normalize(text_features, dim = -1)
image_features = F.normalize(image_features, dim = -1)
logits = self.logit_scale.exp() * (image_features @ text_features.t())
return logits
def encode_text(self, input_ids, attention_mask = None):
""" Tokenize (if needed) and encode texts, returning embeddings and mask. Function for ConditionalPromptNorm """
# tokenize strings if raw text passed
if attention_mask is None:
input_ids, attention_mask = self.tokenizer.tokenize(input_ids)
# ensure batch dim
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
with torch.no_grad():
text_emb = self.text_model(input_ids.long(), attention_mask)
return text_emb
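For reference, one way to verify the claim about matching layer names and sizes is to compare state-dict keys directly. This is a hypothetical sanity check, assuming the Hugging Face transformers package, and not part of the original code:

from transformers import CLIPModel

# Hypothetical sanity check: compare this module's parameter names with the
# openai/clip-vit-base-patch32 checkpoint before copying the weights over.
hf_state = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").state_dict()
my_state = CLIP().state_dict()

print("missing from checkpoint:", [k for k in my_state if k not in hf_state])
print("unused checkpoint keys: ", [k for k in hf_state if k not in my_state])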
The code for the text encoder
I have checked that retrieving the EOS token works correctly. The type of each layer (e.g. nn.Embedding vs. nn.Parameter) is also correct, since a type mismatch would conflict with the checkpoint.
class TextEncoder(nn.Module):
def __init__(self, embed_dim: int = 512):
super(TextEncoder, self).__init__()
vocab_size = 49408
self.embeddings = nn.Module()
self.embeddings.token_embedding = nn.Embedding(vocab_size, embed_dim)
# tokenizer's context_length must be set to 77 tokens
self.embeddings.position_embedding = nn.Embedding(77, embed_dim) # 77 = context length
self.encoder = Encoder(embed_size = embed_dim)
self.final_layer_norm = nn.LayerNorm(embed_dim)
def forward(self, text: torch.Tensor, attention_mask: torch.Tensor):
x = self.embeddings.token_embedding(text.long())
# seq_length
positions = torch.arange(x.size(1))
pos_embed = self.embeddings.position_embedding(positions)
x += pos_embed.to(x.dtype).to(x.device)
# obtain text embeddings
x = x.permute(1, 0, 2)
x = self.encoder(x, attention_mask)
x = x.permute(1, 0, 2)
# ensure batch dim
if x.dim() == 2: x = x.unsqueeze(0)
if attention_mask.dim() == 1: attention_mask = attention_mask.unsqueeze(0)
# for each batch, get the last token (eos)
x = x[torch.arange(x.size(0)), text.argmax(dim = -1)]
return self.final_layer_norm(x)
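The indexing with text.argmax(dim = -1) works because the EOS token has the highest id in CLIP's vocabulary, so the argmax over the token ids points at the EOS position in each sequence. A tiny standalone sketch (the ordinary token ids here are made up):

import torch

# Toy example (made-up token ids): 49407 (<|endoftext|>) is the largest id in
# CLIP's vocabulary, so argmax over the ids gives the EOS position per sequence.
ids = torch.tensor([[49406,  320, 2368, 49407,     0],
                    [49406,  320, 1929, 3989, 49407]])
hidden = torch.randn(2, 5, 512)                        # (batch, seq, embed)
eos_pos = ids.argmax(dim=-1)                           # tensor([3, 4])
pooled = hidden[torch.arange(ids.size(0)), eos_pos]    # shape (2, 512)
print(eos_pos, pooled.shape)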
The attention class is from https://github.com/openai/CLIP/blob/main/clip/model.py#L58, with a slight modification to allow both self-attention and pooled attention (query = x and query = x[:1], respectively).
The Encoder
I have checked that the tokenizer code works correctly. The MLP is the same as in the original CLIP code: two linear layers with a 4x expansion ratio and a GELU in between.
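The MLP class itself isn't shown in the post; a minimal sketch matching that description (the attribute names are guesses, not the actual code) could look like this:

import torch
import torch.nn as nn

class MLP(nn.Module):
    """Two linear layers with a 4x expansion ratio and a GELU in between."""
    def __init__(self, embed_size: int = 512, ratio: int = 4):
        super().__init__()
        self.fc1 = nn.Linear(embed_size, embed_size * ratio)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(embed_size * ratio, embed_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))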
class EncoderLayer(nn.Module):
def __init__(self, embed_size: int = 768, ratio: int = 4, num_heads: int = 8):
super().__init__()
self.layer_norm1 = nn.LayerNorm(embed_size)
self.layer_norm2 = nn.LayerNorm(embed_size)
self.mlp = MLP(embed_size = embed_size, ratio = ratio)
self.self_attn = AttentionPool2d(num_heads = num_heads, embed_dim = embed_size)
def forward(self, x: torch.Tensor, src_pad_key = None):
x = self.layer_norm1(x)
if src_pad_key is not None: attn_out = self.self_attn(x, src_pad_key = src_pad_key, use_self_attention = True)
else: attn_out = self.self_attn(x)
# normalize and apply residual connections
x += attn_out
x = self.layer_norm2(x)
x += self.mlp(x)
return x
class Encoder(nn.Module):
def __init__(self, embed_size = 768):
super().__init__()
self.layers = nn.ModuleList([EncoderLayer(embed_size = embed_size) for _ in range(12)])
def forward(self, x: torch.Tensor, attention_mask = None):
if attention_mask is not None:
src_key_mask = attention_mask == 0
if src_key_mask.dim() == 1: src_key_mask = src_key_mask.unsqueeze(0)
for layer in self.layers:
x = layer(x, src_key_mask)
else:
for layer in self.layers:
x = layer(x)
return x
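For clarity, the mask conversion above follows the usual PyTorch convention: the tokenizer's attention_mask marks real tokens with 1, while a key-padding mask marks the positions to ignore with True. A toy example:

import torch

# attention_mask: 1 = real token, 0 = padding; the key-padding mask flips this
# so that True marks the positions the attention should ignore.
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
src_key_mask = attention_mask == 0
print(src_key_mask)   # tensor([[False, False, False,  True,  True]])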
The issue was in EncoderLayer, where the residual connections were computed incorrectly. The correct way to compute them:
def forward(self, x: torch.Tensor, src_pad_key = None):
residual = x
x = self.layer_norm1(x)
if src_pad_key is not None: x = self.self_attn(x, src_pad_key = src_pad_key, use_self_attention = True)
else: x = self.self_attn(x)
# normalize and apply residual connections
x += residual
residual = x
x = self.layer_norm2(x)
x = self.mlp(x)
x += residual
return x
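This matches the pre-norm residual pattern of the ResidualAttentionBlock in the original CLIP repository, which roughly does the following:

# Pre-norm residual pattern in the original CLIP ResidualAttentionBlock (paraphrased):
# each layer norm is applied to the branch input only, and the untouched input is
# added back after the attention and after the MLP.
def forward(self, x):
    x = x + self.attention(self.ln_1(x))
    x = x + self.mlp(self.ln_2(x))
    return x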
Another change was that we must always use self-attention (query = x) instead of pooled attention, as otherwise the calculations won't line up with the image encoder.
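The difference is only in the query: with self-attention every token produces an output, while pooled attention uses a single query token and collapses the sequence, so the per-token EOS indexing in the text encoder no longer applies. A small sketch with PyTorch's built-in nn.MultiheadAttention (not the modified AttentionPool2d from the post):

import torch
import torch.nn as nn

# Illustration with PyTorch's built-in attention: the only difference
# between the two modes is the query.
attn = nn.MultiheadAttention(embed_dim=512, num_heads=8)
x = torch.randn(77, 1, 512)            # (seq, batch, embed)

self_out, _ = attn(x, x, x)            # query = x     -> (77, 1, 512): one output per token
pooled_out, _ = attn(x[:1], x, x)      # query = x[:1] -> (1, 1, 512): the sequence collapses
print(self_out.shape, pooled_out.shape)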
The results look like this:
Cat similarity: tensor([[25.4132]], grad_fn=<MulBackward0>)
Dog similarity: tensor([[21.8544]], grad_fn=<MulBackward0>)
cosine cat/dog: 0.8438754677772522
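Running the same quick softmax check as above on the fixed logits, the cat prompt now wins clearly:

import torch

# Same quick softmax check as above, now on the fixed logits.
logits = torch.tensor([25.4132, 21.8544])   # cat, dog
print(logits.softmax(dim=-1))               # roughly [0.97, 0.03]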