I have been implementing a custom multi-head attention class in PyTorch for a Transformer model, for learning purposes. My implementation lacks any extra functionality; I just want to make it work for the base case. I've noticed that for causal attention (tokens can't attend to future tokens) my model seems to suffer from data leakage. I came to that conclusion after running the same script with PyTorch's `nn.MultiheadAttention` class instead.
To me, the problem seems to be in the way I apply the mask, but I can't find it. I've tested that a two-dimensional mask broadcasts properly to the 4-dimensional attention tensor (which is my approach), and I've verified several times that the right tokens are masked, to no avail.
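Here is a quick sanity check of that broadcasting claim. It is a minimal sketch: random tensors stand in for the projected queries, keys and values, and the shapes are arbitrary assumptions. It confirms that a single (T, T) mask hides the future in every batch element and head, and that the per-head output matches `F.scaled_dot_product_attention(..., is_causal=True)` (available in PyTorch 2.0 and later). If this passes, the masking step itself is fine and any leakage has to come from somewhere else in the layer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, h, T, dk = 2, 4, 5, 8
q, k, v = (torch.randn(B, h, T, dk) for _ in range(3))

# (T, T) boolean mask: True above the diagonal marks future positions to hide
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()

scores = (q @ k.transpose(-2, -1)) * dk ** -0.5   # (B, h, T, T)
scores = scores.masked_fill(mask, float("-inf"))   # mask broadcasts over B and h
weights = F.softmax(scores, dim=-1)

# Every future position gets exactly zero weight, in every batch element and head
print(torch.all(weights.masked_select(mask) == 0))  # tensor(True)

# Per-head outputs match PyTorch's built-in causal attention
manual = weights @ v                                # (B, h, T, dk)
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(manual, reference, atol=1e-5))  # True
```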
This is the code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# block_size, d_model and n_heads are global hyperparameters defined elsewhere in the script


class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        self.query = nn.Linear(d_model, d_model, bias=False)
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.att_proj = nn.Linear(d_model, d_model, bias=False)
        # True above the diagonal = future positions to hide
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())

    def forward(self, x):
        B, T, C = x.shape
        dk = self.d_model // self.n_heads
        # linear projections
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # split into heads
        q = q.view(B, T, self.n_heads, dk).permute(0, 2, 1, 3)  # B,h,T,dk
        k = k.view(B, T, self.n_heads, dk).permute(0, 2, 1, 3)
        v = v.view(B, T, self.n_heads, dk).permute(0, 2, 1, 3)
        # attention
        x = q @ k.transpose(-2, -1)  # B,h,T,dk @ B,h,dk,T --> B,h,T,T
        x = x * dk ** -0.5  # B,h,T,T
        # slice the mask so shorter sequences (e.g. during generation) also work
        x = x.masked_fill(self.mask[:T, :T], float('-inf'))  # B,h,T,T
        x = F.softmax(x, dim=-1)  # B,h,T,T
        x = x @ v  # B,h,T,T @ B,h,T,dv --> B,h,T,dv
        x = x.view(B, T, -1)
        out = self.att_proj(x)  # B,T,C
        return out
```
With a toy example, my implementation quickly reaches losses such as `Training Loss: 2.307. Evaluation Loss: 2.278`. With torch's `nn.MultiheadAttention` the losses are noticeably higher: `Iteration 9999. Training Loss: 2.469. Evaluation Loss: 2.483`. What am I missing?
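Leakage can also be tested directly instead of being inferred from the loss: perturb only the future positions of the input and check whether earlier outputs change. The sketch below is one way to do that; `check_causal` is a hypothetical helper, not a PyTorch API, and `nn.MultiheadAttention` with a boolean `attn_mask` (True = not allowed to attend) serves only as a known-good reference.

```python
import torch
import torch.nn as nn


def check_causal(fn, B=2, T=8, C=16, atol=1e-5):
    """Return True if fn's output at position t ignores inputs at positions > t.

    fn maps a (B, T, C) tensor to a (B, T, C') tensor.
    """
    torch.manual_seed(0)
    x = torch.randn(B, T, C)
    with torch.no_grad():
        base = fn(x)
        for t in range(T - 1):
            x2 = x.clone()
            x2[:, t + 1:] = torch.randn(B, T - t - 1, C)  # change only the "future"
            out = fn(x2)
            # outputs up to and including position t must be unchanged
            if not torch.allclose(base[:, : t + 1], out[:, : t + 1], atol=atol):
                return False
    return True


# Known-good reference: PyTorch's own attention with a causal boolean mask
T, C = 8, 16
future = torch.triu(torch.ones(T, T), diagonal=1).bool()  # True = not allowed to attend
ref = nn.MultiheadAttention(embed_dim=C, num_heads=4, batch_first=True).eval()


def causal_ref(x):
    return ref(x, x, x, attn_mask=future, need_weights=False)[0]


def unmasked_ref(x):
    return ref(x, x, x, need_weights=False)[0]


print(check_causal(causal_ref, T=T, C=C))    # True
print(check_causal(unmasked_ref, T=T, C=C))  # False
```

To probe the custom layer, something like `check_causal(MultiHeadAttention(n_heads, d_model).eval(), T=block_size, C=d_model)` should work, since its mask is `block_size` x `block_size`. With the `x.view(B,T,-1)` reshape and `n_heads > 1` it should report `False`.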
This is my model implementation, in case the error is there:
```python
# Uses torch, nn, F and the MultiHeadAttention class from the snippet above,
# plus the global hyperparameters d_model, n_heads and block_size.


class Model(nn.Module):
    def __init__(self, vocab_size, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embedding_table = nn.Embedding(vocab_size, d_model)
        self.mha = MultiHeadAttention(n_heads, d_model)
        self.out = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x, targets=None):
        x = self.embedding_table(x)  # B,T,C
        x = self.mha(x)  # B,T,C
        logits = self.out(x)  # B,T,vocab_size
        if targets is not None:
            logits = logits.reshape(-1, logits.shape[-1])
            targets = targets.reshape(-1)
            loss = F.cross_entropy(input=logits, target=targets)
        else:
            loss = None
        return logits, loss

    def generate(self, n_chars, ix):
        for _ in range(n_chars):
            # crop the context so it never exceeds the size of the causal mask
            logits, loss = self(ix[:, -block_size:])  # B,T,vocab_size
            logits = logits[:, -1, :]  # B,vocab_size: keep only the last time step
            probs = F.softmax(logits, dim=-1)  # B,vocab_size
            next_ix = torch.multinomial(input=probs, num_samples=1)  # B,1
            ix = torch.cat((ix, next_ix), dim=1)
        return ix
```
I've tried different train/validation split methods to make sure the leakage wasn't happening there. I've also tried several masking approaches: using `tril` and filling the 0s with -inf, or `triu` and filling the Trues with -inf. I made sure `diagonal=1` so that only future tokens are masked.
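Those two masking variants are equivalent, which fits the observation that switching between them changed nothing. A minimal sketch on a random (T, T) score matrix, where T is an arbitrary size chosen for illustration:

```python
import torch

T = 6
scores = torch.randn(T, T)

# Variant 1: triu with diagonal=1 marks strictly-future positions as True
future = torch.triu(torch.ones(T, T), diagonal=1).bool()
a = scores.masked_fill(future, float("-inf"))

# Variant 2: tril keeps the past and the diagonal; fill the zeros instead
keep = torch.tril(torch.ones(T, T))
b = scores.masked_fill(keep == 0, float("-inf"))

print(torch.equal(future, keep == 0))  # True: both hide exactly the same entries
print(torch.equal(a, b))               # True
```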
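I have tentatively found the problem: I was reshaping one of the intermediate results the wrong way.
After multiplying the attention weights by `v`, I cannot simply do:
```python
x = x.view(B,T,-1)
```
At that point `x` has shape B,h,T,dv, and `view` only reinterprets the memory layout without moving the head dimension. Each row of the result therefore mixes several time steps of a single head, so early positions end up receiving outputs computed for later positions, which is exactly the leakage.
Instead, I should move the head dimension next to the feature dimension before merging:
```python
B,h,T,dv = x.shape
x = x.transpose(2,1).contiguous().view(B,T,h*dv) #B,T,C
```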
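A small sketch makes the difference visible. The toy sizes are arbitrary, and the tensor holds its own flat indices instead of real attention outputs, so you can read off where each value came from. With `view` alone, row 0 of the result (what position 0 passes to `att_proj`) already contains head 0's output for time step 1; the transposed version only combines heads at the same time step.

```python
import torch

B, h, T, dv = 1, 2, 4, 3
# Stand-in for the attention output (B, h, T, dv); each entry is its own flat index
x = torch.arange(B * h * T * dv).view(B, h, T, dv)

# Original reshape: reinterprets the (h, T, dv) memory as (T, h*dv)
wrong = x.view(B, T, -1)
# Fixed reshape: bring T in front of h, then merge h and dv
right = x.transpose(2, 1).contiguous().view(B, T, h * dv)

print(wrong[0, 0].tolist())  # [0, 1, 2, 3, 4, 5]: head 0 at t=0 AND t=1
print(right[0, 0].tolist())  # [0, 1, 2, 12, 13, 14]: heads 0 and 1, both at t=0

# In the fixed version, row t is exactly the concatenation of every head at time t
for t in range(T):
    assert torch.equal(right[0, t], torch.cat([x[0, i, t] for i in range(h)]))
```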