I am calculating the Hessian of a (2, 2) linear model using functorch.hessian, as follows:
import torch
import functorch

device = "cuda"  # or "cpu"

model = torch.nn.Linear(2, 2).to(device)
inputs = torch.rand(1, 2).to(device)
criterion = torch.nn.CrossEntropyLoss()
target = torch.ones(len(inputs), dtype=torch.long).to(device)

func, func_params = functorch.make_functional(model)

def loss(params):
    out = func(params, inputs)
    return criterion(out, target)

H = functorch.hessian(loss)(func_params)
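(Side note: functorch has since been merged into PyTorch as torch.func, and the same computation can be reproduced there. A minimal sketch, assuming PyTorch ≥ 2.0 and running on CPU for simplicity:)

```python
import torch
from torch.func import hessian  # functorch's successor API

torch.manual_seed(0)
model = torch.nn.Linear(2, 2)
inputs = torch.rand(1, 2)
target = torch.ones(len(inputs), dtype=torch.long)
criterion = torch.nn.CrossEntropyLoss()

# Treat (weight, bias) as an explicit tuple of parameters.
params = tuple(model.parameters())

def loss(p):
    w, b = p
    out = inputs @ w.T + b  # same computation as the Linear layer
    return criterion(out, target)

H = hessian(loss)(params)
# H is a tuple of tuples: H[i][j] has shape params[i].shape + params[j].shape
print(H[0][0].shape)  # weight-weight block
print(H[0][1].shape)  # weight-bias block
print(H[1][1].shape)  # bias-bias block
```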
Since I have 2 inputs and 2 outputs, the weight matrix basically has 4 parameters, so I was expecting to see the second derivatives with respect to those 4 parameters, but the output looks different and is hard to interpret. For example, the output of the code above is the following:
((tensor([[[[ 0.0217, 0.0701],
[-0.0217, -0.0701]],
[[ 0.0701, 0.2266],
[-0.0701, -0.2266]]],
[[[-0.0217, -0.0701],
[ 0.0217, 0.0701]],
[[-0.0701, -0.2266],
[ 0.0701, 0.2266]]]], device='cuda:0', grad_fn=<ViewBackward0>),
tensor([[[ 0.0718, -0.0718],
[ 0.2321, -0.2321]],
[[-0.0718, 0.0718],
[-0.2321, 0.2321]]], device='cuda:0', grad_fn=<ViewBackward0>)),
(tensor([[[ 0.0718, 0.2321],
[-0.0718, -0.2321]],
[[-0.0718, -0.2321],
[ 0.0718, 0.2321]]], device='cuda:0', grad_fn=<ViewBackward0>),
tensor([[ 0.2377, -0.2377],
[-0.2377, 0.2377]], device='cuda:0', grad_fn=<ViewBackward0>)))
Does anyone have an idea what's going on? And how can I calculate the trace of the Hessian matrix in this case?
Model parameter shapes:
torch.Size([2, 2]) <- w1 (the weight)
torch.Size([2]) <- w2 (the bias)
Gradient (Jacobian) shapes — the loss is a scalar, so each block has the same shape as its parameter:
J[0].shape = [2, 2] <- dLoss with respect to w1
J[1].shape = [2] <- dLoss with respect to w2
Hessian shapes — each block H[i][j] concatenates the shapes of parameters i and j:
H[0][0].shape = [2, 2, 2, 2] <- dJ[0] with respect to w1
H[0][1].shape = [2, 2, 2] <- dJ[0] with respect to w2
H[1][0].shape = [2, 2, 2] <- dJ[1] with respect to w1
H[1][1].shape = [2, 2] <- dJ[1] with respect to w2
It helps to notice that each block H[i][j] starts with the same dimensions as J[i]. You are differentiating twice with respect to the parameters, so each parameter's shape appears a second time.
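As for the trace: the trace of the full Hessian is the sum of the traces of the diagonal blocks H[i][i], each flattened to a square matrix. A hedged sketch (the helper name hessian_trace is mine, not from functorch; it uses the torch.func successor API on CPU so the example is self-contained):

```python
import math
import torch
from torch.func import hessian

def hessian_trace(H):
    # H[i][i] has shape params[i].shape + params[i].shape;
    # flatten each diagonal block to (n, n) and sum its diagonal.
    tr = 0.0
    for i, row in enumerate(H):
        block = row[i]
        n = math.isqrt(block.numel())
        tr = tr + block.reshape(n, n).diagonal().sum()
    return tr

# Rebuild the example from the question (on CPU for simplicity).
torch.manual_seed(0)
model = torch.nn.Linear(2, 2)
inputs = torch.rand(1, 2)
target = torch.ones(len(inputs), dtype=torch.long)
criterion = torch.nn.CrossEntropyLoss()
params = tuple(model.parameters())

def loss(p):
    w, b = p
    out = inputs @ w.T + b
    return criterion(out, target)

H = hessian(loss)(params)
print(hessian_trace(H))  # a non-negative scalar (the CE-of-linear Hessian is PSD)
```

The square reshape works because each diagonal block pairs a parameter's shape with itself, so its element count is a perfect square.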