pytorch, pytorch-lightning, attention-model, self-attention, vision-transformer



Referring to the ViT attention-maps example in this discussion: https://github.com/huggingface/pytorch-image-models/discussions/1232?sort=old

This code runs perfectly, but I wonder what the parameter `x` in the `my_forward` function refers to, and how and where in the code the value of `x` is passed to `my_forward`.

def my_forward(x):
    B, N, C = x.shape

    qkv = attn_obj.qkv(x).reshape(
        B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads
    ).permute(2, 0, 3, 1, 4)
    # make torchscript happy (cannot use tensor as tuple)
    q, k, v = qkv.unbind(0)

Solution

  • This requires a little code inspection, but you can easily find the implementation if you look in the right places. Let us start with your snippet.
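The short version: in the linked example, `my_forward` is assigned to an attention module's `forward` attribute, so PyTorch's `nn.Module.__call__` machinery passes the module's input tensor to it as `x`. Here is a minimal, self-contained sketch of that mechanism using a plain `nn.Linear` as a stand-in for the ViT attention block (the names `attn_obj`, `my_forward`, and `captured` mirror the discussion but the module itself is hypothetical):

```python
import torch
import torch.nn as nn

# Stand-in for a ViT attention block; in the timm example this would be
# something like model.blocks[-1].attn.
attn_obj = nn.Linear(8, 8)
captured = {}

# Keep a reference to the original forward so we can still run the real computation.
attn_obj_forward_orig = attn_obj.forward

def my_forward(x):
    # `x` is whatever tensor the caller feeds into attn_obj: nn.Module.__call__
    # dispatches to self.forward, and once we overwrite that attribute below,
    # our function receives the module input directly.
    captured["shape"] = tuple(x.shape)
    return attn_obj_forward_orig(x)

# Monkey-patch: every call to attn_obj(...) now routes through my_forward.
attn_obj.forward = my_forward

out = attn_obj(torch.randn(2, 4, 8))  # __call__ -> my_forward(x)
print(captured["shape"])              # (2, 4, 8)
```

So `x` is never passed explicitly by the user's code: the framework supplies it whenever the patched module is invoked during the model's forward pass.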