python, debugging, pytorch, transformer-model, attention-model

PyTorch Linear operations vary widely after reshaping


Here's an example function that illustrates the problem. It's an attempt to extract the Q matrix (for KV caching) from a concatenated QKV projection in a transformer large language model. Before the QKV projection is split into heads and then into the per-head q, k, v projections, manual_q matches unsplit_q. After these operations, not only do manual_q and q differ, they differ widely. What's happening here to cause this discrepancy?

    def get_q_matrix(self, x):
        batch_size, seq_len, n_embd = x.size()

        debug_qkv = self.query_key_value(x)  # shape (batch_size, seq_len, 3 * n_embd)
        unsplit_q, _, _ = debug_qkv.split(
            self.n_embd, dim=-1
        )  # shape (batch_size, seq_len, n_embd)

        debug_qkv = debug_qkv.view(
            batch_size, seq_len, self.n_head, 3 * self.head_size
        )  # shape (batch_size, seq_len, 4, 96)
        q, _, _ = debug_qkv.split(
            self.head_size, dim=-1
        )  # shape (batch_size, seq_len, 4, 96 // 3)

        # Ensure correct weight and bias extraction
        weight = self.query_key_value.weight
        bias = self.query_key_value.bias

        q_weight, k_weight, v_weight = weight.chunk(3, dim=0)
        q_bias, k_bias, v_bias = bias.chunk(3, dim=0)

        manual_q = F.linear(x, q_weight, q_bias)  # shape (batch_size, seq_len, n_embd)

        assert torch.allclose(unsplit_q, manual_q) # Passes
        print(torch.max(torch.abs(unsplit_q - manual_q)))  # tensor(0.)

        manual_q = manual_q.view(batch_size, seq_len, self.n_head, self.head_size)

        print(torch.max(torch.abs(q - manual_q)))  # tensor(35.6218)
        assert torch.allclose(q, manual_q) # AssertionError
        return manual_q

Solution

  • Your view operation changes how the packed QKV values are grouped along the last dimension, and that regrouping causes the discrepancy. Take a simple example:

    Create a dummy tensor representing the packed QKV values. It has a batch size and sequence length of 1, and integer values so it is easy to track where each value ends up.

    import torch
    
    d_emb = 1
    n_heads = 4
    
    # size (1, 1, d_emb*n_heads*3)
    debug_qkv = torch.arange(d_emb*n_heads*3)[None,None,:]
    
    print(debug_qkv.shape)
    > torch.Size([1, 1, 12])
    print(debug_qkv)
    > tensor([[[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]]])
    

    Now we compute unsplit_q by splitting debug_qkv along the final axis. We see that unsplit_q has the values [0, 1, 2, 3]. This makes sense. The last dim of debug_qkv has values from 0 to 11. We split this into 3 contiguous chunks. The first chunk, unsplit_q, has values from 0 to 3.

    unsplit_q, _, _ = debug_qkv.split(n_heads, -1)
    
    print(unsplit_q.shape)
    > torch.Size([1, 1, 4])
    print(unsplit_q)
    > tensor([[[0, 1, 2, 3]]])
    

    Now we look at the view operation. In your code, this is the line debug_qkv = debug_qkv.view(batch_size, seq_len, self.n_head, 3 * self.head_size)

    We can see how the grouping has changed. Because view preserves the row-major order of the underlying data, the final dimension varies fastest: each run of three consecutive values fills one head's slice before the head index advances. The values we are looking for - [0, 1, 2, 3] - have been scattered across heads.

    qkv_reshaped = debug_qkv.view(1, 1, n_heads, d_emb*3)
    print(qkv_reshaped.shape)
    > torch.Size([1, 1, 4, 3])
    print(qkv_reshaped)
    > tensor([[[[ 0,  1,  2],
               [ 3,  4,  5],
               [ 6,  7,  8],
               [ 9, 10, 11]]]])
    

    This propagates forward when we split q from qkv_reshaped:

    q, _, _ = qkv_reshaped.split(d_emb, dim=-1)
    print(q)
    > tensor([[[[0],
               [3],
               [6],
               [9]]]])
    

    We can see clearly how the values get mixed up. unsplit_q has the contiguous values [0, 1, 2, 3], while q has picked up every third value, [0, 3, 6, 9].
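
    Note that view itself never moves any data; it only regroups the same row-major sequence of values. A minimal check with the toy tensors above makes this explicit:

    # the underlying data is identical; only the grouping differs
    print(torch.equal(debug_qkv.flatten(), qkv_reshaped.flatten()))
    > True
    
    # q therefore ends up with every third value of the original sequence
    print(q.flatten())
    > tensor([0, 3, 6, 9])

    So the discrepancy comes entirely from where split now draws its boundaries, not from any values being moved.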

    The solution in this case is to reshape debug_qkv so that the head dimension comes last:

    qkv_reshaped2 = debug_qkv.view(1, 1, d_emb*3, n_heads)
    print(qkv_reshaped2)
    
    > tensor([[[[ 0,  1,  2,  3],
               [ 4,  5,  6,  7],
               [ 8,  9, 10, 11]]]])
    
    q2, _, _ = qkv_reshaped2.split(d_emb, dim=-2)
    print(q2)
    
    > tensor([[[[0, 1, 2, 3]]]])
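
    As a quick check, the recovered q2 now matches unsplit_q once the singleton dimension left over from the split is squeezed out:

    print(torch.equal(q2.squeeze(2), unsplit_q))
    > True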
    

    If you read the source code for PyTorch's MultiheadAttention, you'll find a lot of operations that look like this:

    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

    That reshape keeps head_dim as the last, fastest-varying dimension, so the values belonging to each head stay together in memory - avoiding exactly the kind of silent regrouping shown above.
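
    To map this back onto your get_q_matrix, one common fix is to split the fused output into q, k, v first and only then view each piece into heads. The sketch below assumes the same attributes (self.query_key_value, self.n_embd, self.n_head, self.head_size), the same imports as your snippet (torch, torch.nn.functional as F), and the standard layout in which the first n_embd output rows of the fused weight are the query projection - which your passing unsplit_q / manual_q check already confirms.

    def get_q_matrix(self, x):
        batch_size, seq_len, n_embd = x.size()

        qkv = self.query_key_value(x)             # (batch_size, seq_len, 3 * n_embd)
        q, _, _ = qkv.split(self.n_embd, dim=-1)  # (batch_size, seq_len, n_embd)

        # Separate heads only after Q has been carved out of the fused tensor,
        # so each head receives a contiguous head_size slice of the Q projection.
        q = q.view(batch_size, seq_len, self.n_head, self.head_size)

        # Reference computation using only the Q rows of the fused weight.
        q_weight, _, _ = self.query_key_value.weight.chunk(3, dim=0)
        q_bias, _, _ = self.query_key_value.bias.chunk(3, dim=0)
        manual_q = F.linear(x, q_weight, q_bias)
        manual_q = manual_q.view(batch_size, seq_len, self.n_head, self.head_size)

        assert torch.allclose(q, manual_q)  # now holds
        return q

    Equivalently, you could view the fused tensor as (batch_size, seq_len, 3, self.n_head, self.head_size) and take the first slice along the size-3 dimension; either way, the key is not to merge the QKV factor and head_size into a single axis before Q, K and V have been separated.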