I create a model with a multi-head attention layer:
import torch
import torch.nn as nn

# unbatched inputs: (seq_len, embed_dim) = (2, 4)
query = torch.randn(2, 4)
key = torch.randn(2, 4)
value = torch.randn(2, 4)

# embed_dim=4, num_heads=1, no bias terms
model = nn.MultiheadAttention(4, 1, bias=False)
attn_output, attn_weights = model(query, key, value)
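For reference, the layer packs the query, key, and value projections row-wise into a single in_proj_weight of shape (3 * embed_dim, embed_dim), which is why the slices [:4], [4:8], and [8:12] below pick out W_q, W_k, and W_v:

print(model.in_proj_weight.shape)   # torch.Size([12, 4]): rows 0-3 = W_q, 4-7 = W_k, 8-11 = W_v
print(model.out_proj.weight.shape)  # torch.Size([4, 4])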
I attempt to match the attention output by hand:
# scale by sqrt(embed_dim) = sqrt(4) = 2
softmax_output = torch.softmax(((query @ model.in_proj_weight[:4]) @ (key @ model.in_proj_weight[4:8]).t()) / 2, dim=1)
intermediate_output = softmax_output @ (value @ model.in_proj_weight[8:12])
final_output = intermediate_output @ model.out_proj.weight
but final_output does not match attn_output returned by the layer.
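A minimal check makes the mismatch concrete:

print(torch.allclose(attn_output, final_output))  # False (in general)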
I was able to match the output:
# nn.Linear stores weight as (out_features, in_features), so apply the transpose
q_w = query @ model.in_proj_weight[:4].t()
k_w = key @ model.in_proj_weight[4:8].t()
v_w = value @ model.in_proj_weight[8:12].t()

# scaled dot-product attention: softmax(Q K^T / sqrt(d)) V, with sqrt(4) = 2
softmax_output = torch.softmax(q_w @ k_w.t() / 2, dim=1)
attention = softmax_output @ v_w
final_output = attention @ model.out_proj.weight.t()
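Re-running the check against the layer's output now passes (with a small tolerance, since the built-in forward may accumulate floating-point differences):

print(torch.allclose(attn_output, final_output, atol=1e-6))  # True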
I was missing the transpose earlier: nn.Linear (and the packed in_proj_weight) stores weights as (out_features, in_features), so the forward pass computes x @ W.t(), not x @ W.
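Equivalently, torch.nn.functional.linear applies its weight as x @ W.t() (the nn.Linear convention), so the same computation can be written without explicit transposes. A minimal sketch, assuming the single-head, bias-free layer above:

import torch.nn.functional as F

# F.linear(x, W) computes x @ W.t()
q_w = F.linear(query, model.in_proj_weight[:4])
k_w = F.linear(key, model.in_proj_weight[4:8])
v_w = F.linear(value, model.in_proj_weight[8:12])
softmax_output = torch.softmax(q_w @ k_w.t() / 2, dim=1)
final_output = F.linear(softmax_output @ v_w, model.out_proj.weight)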