I was trying to create my own attention function for a project I'm working on. However, when I compared the output and weights from my code with those from `torch.nn.MultiheadAttention`, I noticed that $\mathrm{softmax}\left(QK^{T}/\sqrt{d_k}\right)$ does not match, so I assume I am calculating it incorrectly. Here is my code:
import torch
import torch.nn.functional as F
from torch.nn import MultiheadAttention
def attention(Q, K, V):
    """Compute plain scaled dot-product attention.

    No learned projections are applied: the inputs are used directly as
    query, key and value.

    Args:
        Q: Query tensor; its last dimension is used as ``d_k``.
        K: Key tensor; its last two dimensions are transposed before the
            product with ``Q``.
        V: Value tensor, multiplied by the attention weights.

    Returns:
        A tuple ``(attn_output, attn_output_weights)`` where the weights are
        ``softmax(Q @ K^T / sqrt(d_k))`` taken over the last dimension and
        the output is ``weights @ V``.
    """
    d_k = Q.size(-1)
    # Scale the raw scores by sqrt(d_k) before normalising them.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k**0.5)
    attn_output_weights = F.softmax(scores, dim=-1)
    attn_output = torch.matmul(attn_output_weights, V)
    return attn_output, attn_output_weights
# Problem sizes used for the comparison.
embed_dim = 8
num_heads = 1
batch_size = 2
seq_len = 5
# Random query/key/value inputs, each created with shape (batch_size, seq_len, embed_dim).
Q = torch.randn(batch_size, seq_len, embed_dim)
K = torch.randn(batch_size, seq_len, embed_dim)
V = torch.randn(batch_size, seq_len, embed_dim)
# Reference module; batch_first=True is passed to match the batch-first inputs above.
multihead_attn = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
# Run the reference module and the custom function on the same inputs.
attn_output_pytorch, attn_output_weights_pytorch = multihead_attn(Q, K, V)
attn_output_custom, attn_output_weights_custom = attention(Q, K, V)
# Compare outputs and weights; these are the checks the question reports as not matching.
assert torch.allclose(attn_output_custom, attn_output_pytorch, rtol=1e-6, atol=1e-8), "Attention output does not match."
assert torch.allclose(attn_output_weights_custom, attn_output_weights_pytorch, rtol=1e-6, atol=1e-8), "Attention weights do not match."
I tried changing the hyperparameters, printing each matrix, not normalizing by the $\sqrt{d_k}$ factor, matching against `torch.nn.functional.scaled_dot_product_attention`, and checking the shape of each tensor, but I still didn't get matching results. I am primarily concerned with matching `attn_output_weights_custom` and `attn_output_weights_pytorch`.
Can someone spot what I might be doing wrong?
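You're not using learned projections: `torch.nn.MultiheadAttention` passes the query, key and value through learned linear layers before computing attention, and passes the result through another learned linear layer afterwards.
If you look at the state dict of the attention module, you'll see: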
# Show the names of the entries stored in the module's state dict.
print(multihead_attn.state_dict().keys())
> odict_keys(['in_proj_weight', 'in_proj_bias', 'out_proj.weight', 'out_proj.bias'])
That might give you an indication of what you're missing. To reproduce pytorch's attention, you need to do the following:
import torch
import torch.nn.functional as F
from torch.nn import MultiheadAttention
import math
def attention(q, k, v,
              embed_dim, num_heads,
              in_proj_weight, in_proj_bias,
              out_proj_weight, out_proj_bias,
              batch_first=True):
    """Multi-head attention using explicitly supplied projection parameters.

    Applies the input projections to ``q``, ``k`` and ``v``, computes scaled
    dot-product attention per head, applies the output projection, and
    averages the attention weights over the heads.

    Args:
        q: Query tensor, ``(batch, tgt_len, embed)`` if ``batch_first`` else
            ``(tgt_len, batch, embed)``.
        k: Key tensor, laid out like ``q`` with its own sequence length.
        v: Value tensor, laid out like ``k``.
        embed_dim: Accepted for interface compatibility; the value actually
            used is read from the last dimension of ``q``.
        num_heads: Number of attention heads; the embedding dimension is
            split evenly across them.
        in_proj_weight: Stacked query/key/value projection weights, split
            into three equal chunks along the first dimension.
        in_proj_bias: Stacked query/key/value projection biases, or ``None``
            when the projections have no bias.
        out_proj_weight: Weight of the output projection.
        out_proj_bias: Bias of the output projection, or ``None``.
        batch_first: Whether inputs and the returned output are batch-first.

    Returns:
        A tuple ``(attn_output, attn_output_weights)``. ``attn_output`` has
        the same layout as ``q``; ``attn_output_weights`` has shape
        ``(batch, tgt_len, src_len)`` and is the mean over heads.
    """
    # Internally everything is computed in (seq, batch, embed) layout.
    if batch_first:
        q = q.transpose(1, 0)
        k = k.transpose(1, 0)
        v = v.transpose(1, 0)

    tgt_len, bsz, embed_dim = q.shape
    src_len = k.shape[0]
    head_dim = embed_dim // num_heads

    # Split the stacked projection parameters. Use the arguments passed in,
    # not a module that happens to exist at global scope.
    w_q, w_k, w_v = in_proj_weight.chunk(3)
    if in_proj_bias is None:
        b_q = b_k = b_v = None
    else:
        b_q, b_k, b_v = in_proj_bias.chunk(3)

    q = F.linear(q, w_q, b_q)
    k = F.linear(k, w_k, b_k)
    v = F.linear(v, w_v, b_v)

    # Fold the heads into the batch dimension: (bsz * num_heads, seq, head_dim).
    q = q.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    k = k.reshape(src_len, bsz * num_heads, head_dim).transpose(0, 1)
    v = v.reshape(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)

    # Scale the query by 1/sqrt(head_dim) before taking the dot product.
    q_scaled = q * math.sqrt(1.0 / float(head_dim))

    attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
    attn_output_weights = F.softmax(attn_output_weights, dim=-1)

    # Weighted sum of values, then merge the heads back into the embedding.
    attn_output = torch.bmm(attn_output_weights, v)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
    attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

    # Average the attention weights across heads.
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    attn_output_weights = attn_output_weights.mean(dim=1)

    if batch_first:
        attn_output = attn_output.transpose(1, 0)
    return attn_output, attn_output_weights
# Problem sizes used for the comparison.
embed_dim = 8
num_heads = 1
batch_size = 2
seq_len = 5
# Random query/key/value inputs, each created with shape (batch_size, seq_len, embed_dim).
Q = torch.randn(batch_size, seq_len, embed_dim)
K = torch.randn(batch_size, seq_len, embed_dim)
V = torch.randn(batch_size, seq_len, embed_dim)
# Reference module; batch_first=True is passed to match the batch-first inputs above.
multihead_attn = MultiheadAttention(embed_dim=embed_dim,
num_heads=num_heads,
batch_first=True)
# Reference result from the PyTorch module.
attn_output_pytorch, attn_output_weights_pytorch = multihead_attn(Q, K, V)
# Custom result, computed with the module's own projection weights and biases.
attn_output_custom, attn_output_weights_custom = attention(Q, K, V,
embed_dim,
num_heads,
multihead_attn.in_proj_weight,
multihead_attn.in_proj_bias,
multihead_attn.out_proj.weight,
multihead_attn.out_proj.bias,
batch_first=True)
# Compare both results using torch.allclose with its default tolerances.
assert torch.allclose(attn_output_custom, attn_output_pytorch), "Attention output does not match."
assert torch.allclose(attn_output_weights_custom, attn_output_weights_pytorch), "Attention weights do not match."
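If you run the above code a bunch of times, you'll encounter a few instances where the `allclose` check fails. This is because PyTorch uses optimized compiled kernels under the hood, so the floating-point operations are not performed in exactly the same order as in the code above and the results can differ slightly. Passing a slightly looser tolerance (for example `atol=1e-6`) to `torch.allclose` avoids these spurious failures. Overall, this is the attention algorithm you are looking for.
You can see the full PyTorch implementation in `torch.nn.functional.multi_head_attention_forward`.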