Tags: pytorch, transformer-model, attention-model, large-language-model, multihead-attention

Inputs and Outputs Mismatch of Multi-head Attention Module (Tensorflow VS PyTorch)


I am trying to convert my TensorFlow model, which uses the layers.MultiHeadAttention module from tf.keras, to nn.MultiheadAttention from torch.nn. Below are the snippets.

  1. TensorFlow Multi-head Attention
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

x_sfe_tf = np.random.randn(64, 345, 64)  # query: (batch, seq_len, embed_dim)
x_te_tf = np.random.randn(64, 200, 64)   # value: (batch, seq_len, embed_dim)

tes_mod_tf = layers.MultiHeadAttention(num_heads=2, key_dim=64)
output_tf = tes_mod_tf(x_sfe_tf, x_te_tf)

print(output_tf.shape)
  2. PyTorch Multi-head Attention
import torch
import torch.nn as nn

x_sfe_torch = torch.randn(64, 345, 64)  # (batch, seq_len, embed_dim)
x_te_torch = torch.randn(64, 200, 64)   # (batch, seq_len, embed_dim)

tes_mod_torch = nn.MultiheadAttention(embed_dim=64, num_heads=2)
output_torch = tes_mod_torch(x_sfe_torch, x_sfe_torch, x_te_torch)  # raises the AssertionError quoted below
print(output_torch.shape)

When I run the TensorFlow MHA, it successfully returns (64, 345, 64). But when I run the PyTorch MHA, it raises this error: AssertionError: key shape torch.Size([64, 345, 64]) does not match value shape torch.Size([64, 200, 64])

The TensorFlow version returns an output with the shape of x_sfe, ignoring its size difference from x_te. On the other hand, the PyTorch version requires x_sfe and x_te to have the same dimensions. I am confused about how TensorFlow's Multi-head Attention module actually works. What is the difference in PyTorch, and what is the correct input for it? Thanks in advance.


Solution

  • TensorFlow takes its input as [batch_size, seq_len, embed_dim], while PyTorch's nn.MultiheadAttention by default expects [seq_len, batch_size, embed_dim]. You can reorder the dimensions with torch.permute(); see the sketch below. I hope this solves your issue.
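
Below is a minimal sketch of that reordering, assuming the goal is to reproduce the TensorFlow call from the question. Two details go beyond the bullet above and come from the public APIs rather than from the answer itself: Keras calls the layer as (query, value) with key defaulting to value, whereas nn.MultiheadAttention.forward takes (query, key, value), so x_te is passed as both key and value here; and newer PyTorch versions (1.9+) also accept batch_first=True, which avoids the permute entirely. Variable names such as mha and mha_bf are just for illustration.

import torch
import torch.nn as nn

x_sfe_torch = torch.randn(64, 345, 64)  # query, (batch, seq_len, embed_dim)
x_te_torch = torch.randn(64, 200, 64)   # key/value, (batch, seq_len, embed_dim)

mha = nn.MultiheadAttention(embed_dim=64, num_heads=2)

# Default layout is (seq_len, batch, embed_dim), so swap the batch and sequence dims
q = x_sfe_torch.permute(1, 0, 2)   # (345, 64, 64)
kv = x_te_torch.permute(1, 0, 2)   # (200, 64, 64)

# forward(query, key, value); x_te serves as both key and value, mirroring the Keras call
attn_out, attn_weights = mha(q, kv, kv)
print(attn_out.permute(1, 0, 2).shape)  # torch.Size([64, 345, 64]), same as TensorFlow

# Alternative on PyTorch >= 1.9: keep batch-first tensors and skip the permute
mha_bf = nn.MultiheadAttention(embed_dim=64, num_heads=2, batch_first=True)
out_bf, _ = mha_bf(x_sfe_torch, x_te_torch, x_te_torch)
print(out_bf.shape)  # torch.Size([64, 345, 64])

Note that the output shape follows the query (x_sfe), just as in TensorFlow, while key and value must share the same sequence length.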