machine-learning  pytorch  matrix-multiplication

How is nn.Linear applied to higher-dimensional data?


I'm trying to understand which values get multiplied when I have higher-dimensional tensors:

inp = torch.rand(1,2,3) # B,C,W
linear = nn.Linear(3,4)
out = linear(inp)
print(out.shape)
>>> torch.Size([1, 2, 4])

inp = torch.rand(1,2,3,4) # B,C,W,H
linear = nn.Linear(4,5)
out = linear(inp)
print(out.shape)
>>> torch.Size([1, 2, 3, 5])

It seems like only the last dimension is being changed, but when I try to manually multiply the linear weights (linear.weight.data) with each of inp's last-dimension slices, I can't reproduce the correct answer (it seems like all the values change, and only the last dimension's size is somehow modified).


Solution

  • The linear layer implements an affine transformation, that is, a matrix multiplication followed by a translation by a bias vector. It is applied to every matrix of the input tensor, where the tensor is viewed as some layout of matrices formed by its last two dimensions.
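
    In other words, for an input of shape (..., in_features), the layer consumes only the last dimension and treats every leading dimension as a batch dimension. A minimal, self-contained sketch of that rule (variable names here are illustrative, not taken from the question):

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 5)      # in_features=4, out_features=5
    x = torch.rand(1, 2, 3, 4)   # any leading dimensions are fine
    y = layer(x)                 # shape (1, 2, 3, 5)
    # applying the layer to a single length-4 slice gives exactly the same row:
    print(torch.allclose(y[0, 1, 2], layer(x[0, 1, 2])))  # True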


    Your second example is better suited for a demonstration. Here it is, slightly modified to also print the weights, the bias, and the shapes:

    inp = torch.rand(1,2,3,4) # B,C,W,H
    print(f'inp ({inp.shape}):\n{inp}')
    linear = nn.Linear(4,5)
    print(f'weight ({linear.weight.shape}):\n{linear.weight}')
    print(f'bias {linear.bias.shape}:\n{linear.bias}')
    out = linear(inp)
    print(f'out ({out.shape}):\n{out}')
    

    prints the following (your exact values will differ due to the random initialization)

    inp (torch.Size([1, 2, 3, 4])):
    tensor([[[[0.3340, 0.6843, 0.6702, 0.9667],
              [0.9990, 0.3094, 0.7772, 0.7851],
              [0.3004, 0.6993, 0.3088, 0.5238]],
    
             [[0.5257, 0.9793, 0.2408, 0.4065],
              [0.7183, 0.8921, 0.8280, 0.1272],
              [0.7826, 0.2930, 0.1266, 0.8724]]]])
    weight (torch.Size([5, 4])):
    Parameter containing:
    tensor([[ 0.2177, -0.0575, -0.4756, -0.1297],
            [ 0.3632, -0.2986, -0.0157, -0.2817],
            [ 0.4323,  0.3205,  0.2895, -0.1527],
            [ 0.2368,  0.4018, -0.2126,  0.4732],
            [-0.0158, -0.4908,  0.3854, -0.4685]], requires_grad=True)
    bias torch.Size([5]):
    Parameter containing:
    tensor([-0.0749, -0.1002, -0.3814,  0.2213, -0.4468], requires_grad=True)
    out (torch.Size([1, 2, 3, 5])):
    tensor([[[[-0.4856, -0.4660,  0.0287,  0.8903, -0.9825],
              [-0.3466, -0.0631,  0.2547,  0.7885, -0.6827],
              [-0.2645, -0.3523, -0.0180,  0.7556, -0.9211]],
    
             [[-0.1840, -0.3200,  0.1673,  0.8805, -1.0334],
              [-0.3801, -0.1546,  0.4352,  0.6340, -0.6364],
              [-0.0947, -0.1511, -0.0458,  0.9102, -0.9628]]]],
           grad_fn=<ViewBackward0>)
    

    So here inp contains two 3x4 matrices, embedded into a 4-dimensional tensor according to the 1x2 layout of the first two dimensions. The linear layer multiplies each matrix (from the right) by the transposed weight matrix, adds the bias vector to each result, and finally stacks the resulting two 3x5 matrices back according to the same 1x2 layout. You can perform the linear(inp) operation equivalently as

    torch.tensordot(inp, linear.weight.T, dims=1) + linear.bias
    

    which produces the same result as out. The dims=1 argument of torch.tensordot specifies how many dimensions the operation "consumes", which is 1 for the matrix product.
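
    As a quick sanity check (reusing inp, linear, and out from above; torch.nn.functional.linear is the functional form of the same layer), the following equivalent expressions all reproduce out:

    import torch.nn.functional as F

    ref = torch.tensordot(inp, linear.weight.T, dims=1) + linear.bias
    print(torch.allclose(out, ref))                                        # True
    print(torch.allclose(out, inp @ linear.weight.T + linear.bias))        # True
    print(torch.allclose(out, F.linear(inp, linear.weight, linear.bias)))  # True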


    To make it even clearer, you can use plain matrix products via torch.matmul and perform the unstacking and stacking manually:

    mat0 = torch.matmul(inp[0][0], linear.weight.T) + linear.bias
    mat1 = torch.matmul(inp[0][1], linear.weight.T) + linear.bias
    print(f'mat0 ({mat0.shape}):\n{mat0}')
    print(f'mat1 ({mat1.shape}):\n{mat1}')
    # =>
    # mat0 (torch.Size([3, 5])):
    # tensor([[-0.4856, -0.4660,  0.0287,  0.8903, -0.9825],
    #         [-0.3466, -0.0631,  0.2547,  0.7885, -0.6827],
    #         [-0.2645, -0.3523, -0.0180,  0.7556, -0.9211]], grad_fn=<AddBackward0>)
    # mat1 (torch.Size([3, 5])):
    # tensor([[-0.1840, -0.3200,  0.1673,  0.8805, -1.0334],
    #         [-0.3801, -0.1546,  0.4352,  0.6340, -0.6364],
    #         [-0.0947, -0.1511, -0.0458,  0.9102, -0.9628]], grad_fn=<AddBackward0>)
    

    which can be stacked with torch.stack into a 2x3x5 tensor, or further into a 1x2x3x5 tensor matching out above:

    t1 = torch.stack([mat0, mat1])
    print(f't1 ({t1.shape}):\n{t1}')
    # =>
    # t1 (torch.Size([2, 3, 5])):
    # tensor([[[-0.4856, -0.4660,  0.0287,  0.8903, -0.9825],
    #          [-0.3466, -0.0631,  0.2547,  0.7885, -0.6827],
    #          [-0.2645, -0.3523, -0.0180,  0.7556, -0.9211]],
    #
    #         [[-0.1840, -0.3200,  0.1673,  0.8805, -1.0334],
    #          [-0.3801, -0.1546,  0.4352,  0.6340, -0.6364],
    #          [-0.0947, -0.1511, -0.0458,  0.9102, -0.9628]]],
    #        grad_fn=<StackBackward0>)
    
    t2 = torch.stack([t1])
    print(f't2 ({t2.shape}):\n{t2}')
    # =>
    # t2 (torch.Size([1, 2, 3, 5])):
    # tensor([[[[-0.4856, -0.4660,  0.0287,  0.8903, -0.9825],
    #           [-0.3466, -0.0631,  0.2547,  0.7885, -0.6827],
    #           [-0.2645, -0.3523, -0.0180,  0.7556, -0.9211]],
    #
    #          [[-0.1840, -0.3200,  0.1673,  0.8805, -1.0334],
    #           [-0.3801, -0.1546,  0.4352,  0.6340, -0.6364],
    #           [-0.0947, -0.1511, -0.0458,  0.9102, -0.9628]]]],
    #        grad_fn=<StackBackward0>)
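
    Finally, a one-line check (using the variables defined above) confirms that this manual computation reproduces the layer's output exactly:

    print(torch.allclose(t2, out))  # True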