Consider a code fragment from Crossformer:
def forward(self, queries, keys, values):
    B, L, H, E = queries.shape
    _, S, _, D = values.shape
    scale = self.scale or 1. / sqrt(E)

    scores = torch.einsum("blhe,bshe->bhls", queries, keys)
    A = self.dropout(torch.softmax(scale * scores, dim=-1))
    V = torch.einsum("bhls,bshd->blhd", A, values)

    return V.contiguous()
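To make the layout concrete, here is how I read those two einsum calls - a plain-matmul sketch of my own (the helper name and the permutes are my interpretation of the subscripts, not code from the Crossformer repository):

import torch

def naive_attention_matmul(queries, keys, values, scale, dropout):
    # queries: (B, L, H, E), keys: (B, S, H, E), values: (B, S, H, D)
    # "blhe,bshe->bhls": move heads next to batch, then Q @ K^T per head
    scores = queries.permute(0, 2, 1, 3) @ keys.permute(0, 2, 3, 1)   # (B, H, L, S)
    A = dropout(torch.softmax(scale * scores, dim=-1))                # (B, H, L, S)
    # "bhls,bshd->blhd": A @ V per head, then move heads back behind the sequence dim
    V = (A @ values.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)          # (B, L, H, D)
    return V.contiguous()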
I'm trying to accelerate this forward pass by replacing the naive einsum implementation with Flash Attention. To that end, I did the following:
def forward(self, queries, keys, values):
    # I'm not sure about the below - it's just a ChatGPT-assisted guess:
    # B represents the batch size.
    # L is the sequence length for queries (or target sequence length).
    # H is the number of attention heads.
    # E is the depth (dimension) of each attention head for queries/keys.
    # S is the sequence length for keys/values (or source sequence length).
    # D is the depth (dimension) of each attention head for values.
    B, L, H, E = queries.shape
    _, S, _, D = values.shape

    y = torch.nn.functional.scaled_dot_product_attention(
        queries, keys, values,
        dropout_p=self.dropout_p if self.training else None)
    y = y.contiguous()
    return y
However, with the above code, I'm getting the following error:
RuntimeError: The size of tensor a (10) must match the size of tensor b (4) at
non-singleton dimension 1
The debugger shows me the following tensor sizes:
keys:    (2048, 4, 16, 32)
queries: (2048, 10, 16, 32)
values:  (2048, 4, 16, 32)

What am I missing in this change?
scaled_dot_product_attention expects its 4-D inputs in (batch, heads, seq_len, head_dim) layout, i.e. the sequence dimension must be at dimension -2
(see the documentation), whereas the Crossformer tensors are laid out as (batch, seq_len, heads, head_dim).
Thus you must transpose dimension 1 with dimension 2 in your case, and transpose the result back:
y = torch.nn.functional.scaled_dot_product_attention(
    queries.transpose(1, 2),   # (B, L, H, E) -> (B, H, L, E)
    keys.transpose(1, 2),      # (B, S, H, E) -> (B, H, S, E)
    values.transpose(1, 2),    # (B, S, H, D) -> (B, H, S, D)
    dropout_p=self.dropout_p if self.training else 0.0
).transpose(1, 2)              # back to (B, L, H, D)
y = y.contiguous()
return y
Also note that dropout_p must be a number, not None; pass 0.0 when dropout should not be applied.
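Putting both points together, a drop-in forward could look like the sketch below. It assumes your module keeps the original self.scale and the self.dropout_p attribute from your attempt; the explicit scale= keyword needs PyTorch 2.1 or newer, and when self.scale is None it falls back to the default 1/sqrt(E), matching the Crossformer behaviour:

import torch.nn.functional as F

def forward(self, queries, keys, values):
    # queries: (B, L, H, E), keys: (B, S, H, E), values: (B, S, H, D)
    y = F.scaled_dot_product_attention(
        queries.transpose(1, 2),                             # (B, H, L, E)
        keys.transpose(1, 2),                                # (B, H, S, E)
        values.transpose(1, 2),                              # (B, H, S, D)
        dropout_p=self.dropout_p if self.training else 0.0,  # a float, never None
        scale=self.scale,                                     # None -> 1/sqrt(E) (PyTorch >= 2.1)
    )
    return y.transpose(1, 2).contiguous()                    # back to (B, L, H, D)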