Here's an example function that illustrates this problem. It's an attempt to get the Q matrix (for KV caching) out of a concatenated QKV matrix (from a transformer large language model). Before the QKV projections are split into heads and then split into q, k, v projections per head, manual_q matches unsplit_q. After those operations, manual_q and q don't just differ, they differ widely. What's happening here to cause this discrepancy?
import torch
import torch.nn.functional as F

def get_q_matrix(self, x):
    batch_size, seq_len, n_embd = x.size()
    debug_qkv = self.query_key_value(x)  # shape (batch_size, seq_len, 3 * n_embd)
    unsplit_q, _, _ = debug_qkv.split(
        self.n_embd, dim=-1
    )  # shape (batch_size, seq_len, n_embd)
    debug_qkv = debug_qkv.view(
        batch_size, seq_len, self.n_head, 3 * self.head_size
    )  # shape (batch_size, seq_len, 4, 96)
    q, _, _ = debug_qkv.split(
        self.head_size, dim=-1
    )  # shape (batch_size, seq_len, 4, 96 // 3)
    # Ensure correct weight and bias extraction
    weight = self.query_key_value.weight
    bias = self.query_key_value.bias
    q_weight, k_weight, v_weight = weight.chunk(3, dim=0)
    q_bias, k_bias, v_bias = bias.chunk(3, dim=0)
    manual_q = F.linear(x, q_weight, q_bias)  # shape (batch_size, seq_len, n_embd)
    assert torch.allclose(unsplit_q, manual_q)  # Passes
    print(torch.max(torch.abs(unsplit_q - manual_q)))  # tensor(0.)
    manual_q = manual_q.view(batch_size, seq_len, self.n_head, self.head_size)
    print(torch.max(torch.abs(q - manual_q)))  # tensor(35.6218)
    assert torch.allclose(q, manual_q)  # AssertionError
    return manual_q
Your view operations are changing how the data is laid out in the tensor, and that's what causes the discrepancy. Take a simple example: create a dummy tensor representing the packed QKV values, with batch size and sequence length of 1 and integer values so it's easy to track how they move around.
import torch
d_emb = 1
n_heads = 4
# size (1, 1, d_emb*n_heads*3)
debug_qkv = torch.arange(d_emb*n_heads*3)[None,None,:]
print(debug_qkv.shape)
> torch.Size([1, 1, 12])
print(debug_qkv)
> tensor([[[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]])
Now we compute unsplit_q by splitting debug_qkv along the final axis. We see that unsplit_q has the values [0, 1, 2, 3]. This makes sense: the last dim of debug_qkv has values from 0 to 11, we split it into 3 contiguous chunks, and the first chunk, unsplit_q, has the values 0 to 3.
unsplit_q, _, _ = debug_qkv.split(n_heads, -1)
print(unsplit_q.shape)
> torch.Size([1, 1, 4])
print(unsplit_q)
> tensor([[[0, 1, 2, 3]]])
Now we look at the view operation. In your code, this is the line debug_qkv = debug_qkv.view(batch_size, seq_len, self.n_head, 3 * self.head_size). We can see how the tensor layout has changed: view walks the underlying data in order, so it fills the final dimension of one head before moving on to the next head. The values we are looking for - [0, 1, 2, 3] - have actually been split up across heads.
qkv_reshaped = debug_qkv.view(1, 1, n_heads, d_emb*3)
print(qkv_reshaped.shape)
> torch.Size([1, 1, 4, 3])
print(qkv_reshaped)
> tensor([[[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]]])
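It's worth emphasizing that view never copies or reorders anything; it just reads the same flat, row-major buffer with a new shape. A quick check on the toy tensors:
# the reshaped tensor walks the exact same underlying values in the same order
print(torch.equal(debug_qkv.flatten(), qkv_reshaped.flatten()))
> True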
This propagates forward when we split q from qkv_reshaped:
q, _, _ = qkv_reshaped.split(d_emb, dim=-1)
print(q)
> tensor([[[[0],
[3],
[6],
[9]]]])
We can see clearly how the values are mixed up. unsplit_q has the contiguous values [0, 1, 2, 3], while q has picked up the values [0, 3, 6, 9].
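Another way to see it: in this toy example the q we ended up with is just a strided selection of the packed tensor - every third element - rather than its first contiguous block:
# q picked up elements 0, 3, 6, 9, i.e. a stride-3 slice of the packed values
print(torch.equal(q.flatten(), debug_qkv[..., ::3].flatten()))
> True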
The solution in this case is to reshape debug_qkv so that the head dimension comes last:
qkv_reshaped2 = debug_qkv.view(1, 1, d_emb*3, n_heads)
print(qkv_reshaped2)
> tensor([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]]])
q2, _, _ = qkv_reshaped2.split(d_emb, dim=-2)
print(q2)
> tensor([[[[0, 1, 2, 3]]]])
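Applied back to your get_q_matrix, the simplest fix is to split Q, K and V apart before viewing into heads. Here's a sketch; it assumes query_key_value packs Q, K and V as three contiguous blocks of size n_embd, which is what your passing unsplit_q / manual_q check already implies:
def get_q_matrix(self, x):
    batch_size, seq_len, n_embd = x.size()
    qkv = self.query_key_value(x)  # shape (batch_size, seq_len, 3 * n_embd)
    # separate Q, K and V first, while each block is still contiguous along the last dim
    q, k, v = qkv.split(self.n_embd, dim=-1)  # each (batch_size, seq_len, n_embd)
    # only then view into heads, so every head gets a contiguous head_size slice of the Q block
    q = q.view(batch_size, seq_len, self.n_head, self.head_size)
    return q
This q is exactly manual_q.view(batch_size, seq_len, self.n_head, self.head_size), so the allclose check that currently fails would now pass.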
If you read the source code for PyTorch's MultiheadAttention, you'll find a lot of operations that look like this:
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
This operation reshapes so that head_dim stays the last, contiguous dimension, precisely because of this issue.
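To see why that pattern is safe, here's a toy version with made-up sizes (tgt_len = bsz = 1, num_heads = 2, head_dim = 4). Because head_dim is kept as the last dimension, each head comes out with its own contiguous block of values rather than a strided one:
tgt_len, bsz, num_heads, head_dim = 1, 1, 2, 4
# packed projection output: head 0 owns values 0-3, head 1 owns values 4-7
q = torch.arange(num_heads * head_dim).view(tgt_len, bsz, num_heads * head_dim)
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
print(q.shape)
> torch.Size([2, 1, 4])
print(q[0], q[1])
> tensor([[0, 1, 2, 3]]) tensor([[4, 5, 6, 7]])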