I'm trying to understand what values are being multiplied when I have higher-dimensional tensors:
inp = torch.rand(1,2,3) # B,C,W
linear = nn.Linear(3,4)
out = linear(inp)
print(out.shape)
>>> torch.Size([1, 2, 4])
inp = torch.rand(1,2,3,4) # B,C,W,H
linear = nn.Linear(4,5)
out = linear(inp)
print(out.shape)
>>> torch.Size([1, 2, 3, 5])
It seems like only the last dimension is being changed, but when I try to manually multiply the linear weights (linear.weight.data) with each of inp's last-dimension vectors, I can't get the correct answer (it seems like all the values change, and only the last dimension's size is being modified somehow).
The linear layer implements an affine transformation, i.e. a matrix multiplication followed by the addition of a bias vector (out = inp @ weight.T + bias), and it is applied to every matrix in the input tensor, where the tensor is viewed as some layout of matrices given by its leading dimensions.
It is better to consider your second example for the demonstration. Your (slightly modified) second example:
inp = torch.rand(1,2,3,4) # B,C,W,H
print(f'inp ({inp.shape}):\n{inp}')
linear = nn.Linear(4,5)
print(f'weight ({linear.weight.shape}):\n{linear.weight}')
print(f'bias {linear.bias.shape}:\n{linear.bias}')
out = linear(inp)
print(f'out ({out.shape}):\n{out}')
prints the following (your values will differ due to random initialization):
inp (torch.Size([1, 2, 3, 4])):
tensor([[[[0.3340, 0.6843, 0.6702, 0.9667],
[0.9990, 0.3094, 0.7772, 0.7851],
[0.3004, 0.6993, 0.3088, 0.5238]],
[[0.5257, 0.9793, 0.2408, 0.4065],
[0.7183, 0.8921, 0.8280, 0.1272],
[0.7826, 0.2930, 0.1266, 0.8724]]]])
weight (torch.Size([5, 4])):
Parameter containing:
tensor([[ 0.2177, -0.0575, -0.4756, -0.1297],
[ 0.3632, -0.2986, -0.0157, -0.2817],
[ 0.4323, 0.3205, 0.2895, -0.1527],
[ 0.2368, 0.4018, -0.2126, 0.4732],
[-0.0158, -0.4908, 0.3854, -0.4685]], requires_grad=True)
bias torch.Size([5]):
Parameter containing:
tensor([-0.0749, -0.1002, -0.3814, 0.2213, -0.4468], requires_grad=True)
out (torch.Size([1, 2, 3, 5])):
tensor([[[[-0.4856, -0.4660, 0.0287, 0.8903, -0.9825],
[-0.3466, -0.0631, 0.2547, 0.7885, -0.6827],
[-0.2645, -0.3523, -0.0180, 0.7556, -0.9211]],
[[-0.1840, -0.3200, 0.1673, 0.8805, -1.0334],
[-0.3801, -0.1546, 0.4352, 0.6340, -0.6364],
[-0.0947, -0.1511, -0.0458, 0.9102, -0.9628]]]],
grad_fn=<ViewBackward0>)
So here inp contains two 3x4 matrices embedded into a 4-dimensional tensor according to the 1x2 layout of the first two dimensions. The linear layer multiplies every matrix by the (transposed) weights, adds the bias vector to each result, and finally stacks the resulting two 3x5 matrices back into the 1x2 layout. You can perform the linear(inp) operation equivalently as
torch.tensordot(inp, linear.weight.T, dims=1) + linear.bias
which prints the same result as out.
The dims=1 argument of torch.tensordot specifies how many dimensions the operation "consumes", which is 1 for the matrix product.
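If you prefer plain broadcasting, the same result can also be obtained with the @ operator, which batches the matrix product over all leading dimensions. A small sanity check (assuming the inp, linear and out from above):
manual = inp @ linear.weight.T + linear.bias  # broadcasted matmul plus bias
print(torch.allclose(out, manual))
# => True
print(torch.allclose(out, torch.tensordot(inp, linear.weight.T, dims=1) + linear.bias))
# => True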
To make it even clearer, you can use only plain matrix products via torch.matmul and perform the unstacking and stacking manually:
mat0 = torch.matmul(inp[0][0], linear.weight.T) + linear.bias
mat1 = torch.matmul(inp[0][1], linear.weight.T) + linear.bias
print(f'mat0 ({mat0.shape}):\n{mat0}')
print(f'mat1 ({mat1.shape}):\n{mat1}')
# =>
# mat0 (torch.Size([3, 5])):
# tensor([[-0.4856, -0.4660, 0.0287, 0.8903, -0.9825],
# [-0.3466, -0.0631, 0.2547, 0.7885, -0.6827],
# [-0.2645, -0.3523, -0.0180, 0.7556, -0.9211]], grad_fn=<AddBackward0>)
# mat1 (torch.Size([3, 5])):
# tensor([[-0.1840, -0.3200, 0.1673, 0.8805, -1.0334],
# [-0.3801, -0.1546, 0.4352, 0.6340, -0.6364],
# [-0.0947, -0.1511, -0.0458, 0.9102, -0.9628]], grad_fn=<AddBackward0>)
which can be stacked by torch.stack into a 2x3x5 tensor or a 1x2x3x5 tensor (as above):
t1 = torch.stack([mat0, mat1])
print(f't1 ({t1.shape}):\n{t1}')
# =>
# t1 (torch.Size([2, 3, 5])):
# tensor([[[-0.4856, -0.4660, 0.0287, 0.8903, -0.9825],
# [-0.3466, -0.0631, 0.2547, 0.7885, -0.6827],
# [-0.2645, -0.3523, -0.0180, 0.7556, -0.9211]],
#
# [[-0.1840, -0.3200, 0.1673, 0.8805, -1.0334],
# [-0.3801, -0.1546, 0.4352, 0.6340, -0.6364],
# [-0.0947, -0.1511, -0.0458, 0.9102, -0.9628]]],
# grad_fn=<StackBackward0>)
t2 = torch.stack([t1])
print(f't2 ({t2.shape}):\n{t2}')
# =>
# t2 (torch.Size([1, 2, 3, 5])):
# tensor([[[[-0.4856, -0.4660, 0.0287, 0.8903, -0.9825],
# [-0.3466, -0.0631, 0.2547, 0.7885, -0.6827],
# [-0.2645, -0.3523, -0.0180, 0.7556, -0.9211]],
#
# [[-0.1840, -0.3200, 0.1673, 0.8805, -1.0334],
# [-0.3801, -0.1546, 0.4352, 0.6340, -0.6364],
# [-0.0947, -0.1511, -0.0458, 0.9102, -0.9628]]]],
# grad_fn=<StackBackward0>)
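Finally, as a sanity check (assuming the same out as above), the manually assembled tensors match the output of the linear layer:
print(torch.allclose(t1, out[0]))
# => True
print(torch.allclose(t2, out))
# => True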