I am trying to convert a TensorFlow model that uses layers.MultiHeadAttention from tf.keras to nn.MultiheadAttention from torch.nn. Below are the snippets.
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
x_sfe_tf = np.random.randn(64, 345, 64)  # [batch_size, seq_len, embed_dim]
x_te_tf = np.random.randn(64, 200, 64)
tes_mod_tf = layers.MultiHeadAttention(num_heads=2, key_dim=64)
output_tf = tes_mod_tf(x_sfe_tf, x_te_tf)
print(output_tf.shape)
import torch
import torch.nn as nn
x_sfe_torch = torch.randn(64, 345, 64)  # [batch_size, seq_len, embed_dim]
x_te_torch = torch.randn(64, 200, 64)
tes_mod_torch = nn.MultiheadAttention(embed_dim=64, num_heads=2)
output_torch = tes_mod_torch(x_sfe_torch, x_sfe_torch, x_te_torch)
print(output_torch.shape)
When I run the TensorFlow MHA, it successfully returns (64, 345, 64). But when I run the PyTorch MHA, it returns this error:
AssertionError: key shape torch.Size([64, 345, 64]) does not match value shape torch.Size([64, 200, 64])
The TensorFlow version returns an output with the shape of x_sfe, ignoring its sequence-length difference from x_te. On the other hand, the PyTorch version requires that x_sfe and x_te have the same dimensions. I am confused about how TensorFlow's multi-head attention module actually works. What is the difference in PyTorch, and what is the correct input for it? Thanks in advance.
TensorFlow's layers.MultiHeadAttention takes inputs shaped [batch_size, seq_len, embed_dim], while PyTorch's nn.MultiheadAttention by default expects [seq_len, batch_size, embed_dim] (unless you construct it with batch_first=True). You can rearrange your tensors with torch.permute(). I hope this solves your issue.
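A minimal sketch of that conversion, using the toy shapes from the question. Note that it also passes x_te as both key and value, since nn.MultiheadAttention's forward takes (query, key, value) and requires key and value to share a sequence length, which mirrors Keras's call(query, value) where key defaults to value:

import torch
import torch.nn as nn

# Toy tensors in the TensorFlow/Keras layout: [batch_size, seq_len, embed_dim]
x_sfe_torch = torch.randn(64, 345, 64)  # query
x_te_torch = torch.randn(64, 200, 64)   # key/value

# Rearrange to PyTorch's default layout: [seq_len, batch_size, embed_dim]
q = x_sfe_torch.permute(1, 0, 2)
kv = x_te_torch.permute(1, 0, 2)

mha = nn.MultiheadAttention(embed_dim=64, num_heads=2)

# forward() takes (query, key, value); x_te is passed as both key and value
attn_output, attn_weights = mha(q, kv, kv)

# Permute back to [batch_size, seq_len, embed_dim]
print(attn_output.permute(1, 0, 2).shape)  # torch.Size([64, 345, 64])

If you prefer to keep the [batch_size, seq_len, embed_dim] layout throughout, constructing the module with nn.MultiheadAttention(embed_dim=64, num_heads=2, batch_first=True) avoids the permutes.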