
Can't find the in-place operation: "one of the variables needed for gradient computation has been modified by an inplace operation"


I am trying to compute a loss on the Jacobian of the network (i.e., to perform double backprop), and I get the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

I can't find the inplace operation in my code, so I don't know which line to fix.
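For context, here is a minimal sketch of what this error means, using a toy tensor rather than the network below. Some operations save their inputs or outputs for the backward pass, and changing a saved tensor in place afterwards makes backward() fail. The in-place call itself runs without complaint, so the traceback points at backward() instead of at the line that caused the problem. The helper name inplace_after_save_fails is purely illustrative.

```python
import torch

def inplace_after_save_fails() -> bool:
    a = torch.randn(3, requires_grad=True)
    b = a.exp()          # exp saves its output b for the backward pass
    b.add_(1)            # in-place: changes the saved b, but raises nothing yet
    try:
        b.sum().backward()   # autograd now sees that the saved b was modified
    except RuntimeError as err:
        return "modified by an inplace operation" in str(err)
    return False

print(inplace_after_save_fails())
```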

The error is raised on the last line, loss3.backward(), of the following code:

import torch

inputs_reg = data.clone().requires_grad_(True)  # leaf copy of the input that tracks gradients
output_reg = self.model(inputs_reg)

num_classes = output_reg.size(1)
jacobian_list = []
grad_output = torch.zeros_like(output_reg)  # same shape, dtype and device as output_reg

for i in range(num_classes):
    # one-hot seed selecting output column i
    grad_output.zero_()
    grad_output[:, i] = 1
    # gradient of output column i w.r.t. the input, kept differentiable for double backprop
    jacobian_list.append(torch.autograd.grad(outputs=output_reg,
                                             inputs=inputs_reg,
                                             grad_outputs=grad_output,
                                             retain_graph=True,
                                             create_graph=True)[0])

jacobian = torch.stack(jacobian_list, dim=0)
loss3 = jacobian.norm()
loss3.backward()
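The same failure can be reproduced without the surrounding class. This sketch assumes a hypothetical nn.Linear(4, 3) model and a random batch as stand-ins for self.model and data. With create_graph=True, the backward pass of the linear layer is itself recorded, and the recorded matrix multiplication saves grad_output so that it can later be differentiated with respect to the weight. Reusing the same grad_output buffer in the next iteration therefore overwrites a tensor that the graph still needs.

```python
import torch
import torch.nn as nn

def reproduce_error() -> bool:
    torch.manual_seed(0)
    model = nn.Linear(4, 3)    # hypothetical stand-in for self.model
    data = torch.randn(2, 4)   # hypothetical batch of 2 inputs

    inputs_reg = data.clone().requires_grad_(True)
    output_reg = model(inputs_reg)
    num_classes = output_reg.size(1)

    jacobian_list = []
    grad_output = torch.zeros_like(output_reg)  # one buffer reused for every class
    for i in range(num_classes):
        grad_output.zero_()       # in-place on a tensor an earlier grad() call saved
        grad_output[:, i] = 1     # also in-place
        jacobian_list.append(torch.autograd.grad(outputs=output_reg,
                                                 inputs=inputs_reg,
                                                 grad_outputs=grad_output,
                                                 retain_graph=True,
                                                 create_graph=True)[0])

    loss3 = torch.stack(jacobian_list, dim=0).norm()
    try:
        loss3.backward()
    except RuntimeError as err:
        return "inplace operation" in str(err)
    return False

print(reproduce_error())
```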

Solution

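  • grad_output.zero_() is an in-place operation, and so is grad_output[:, i] = 1. In-place means that the tensor itself is modified, instead of a new tensor with the changes applied being returned. This matters here because, with create_graph=True, torch.autograd.grad records the backward pass into the graph, and the recorded operations can save grad_output for the second backward. The next iteration's zero_() overwrites that saved tensor, and loss3.backward() detects the change. One out-of-place alternative is torch.where, which returns a new tensor and leaves its inputs untouched. For example, to zero out column 1 (the second column):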

    >>> import torch
    >>> t = torch.randn(3, 3)
    >>> ixs = torch.arange(3, dtype=torch.int64)  # column indices 0, 1, 2
    >>> # take 0 where the column index is 1, otherwise keep the value from t
    >>> zeroed = torch.where(ixs[None, :] == 1, torch.tensor(0.), t)
    >>> zeroed
    tensor([[-0.6616,  0.0000,  0.7329],
            [ 0.8961,  0.0000, -0.1978],
            [ 0.0798,  0.0000, -1.2041]])
    >>> t
    tensor([[-0.6616, -1.6422,  0.7329],
            [ 0.8961, -0.9623, -0.1978],
            [ 0.0798, -0.7733, -1.2041]])
    

    Notice that t keeps its original values, while zeroed is a new tensor with column 1 set to zero.
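Applied to the code in the question, one way to avoid the in-place updates is to build a fresh one-hot grad_output for every class with torch.where, so no tensor saved by an earlier autograd.grad call is ever modified. This is a sketch that assumes a hypothetical nn.Linear(4, 3) model and a random batch as stand-ins for self.model and data; with a linear layer, every row of the Jacobian for class i equals weight[i], which makes the result easy to check.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 3)    # hypothetical stand-in for self.model
data = torch.randn(2, 4)   # hypothetical batch of 2 inputs

inputs_reg = data.clone().requires_grad_(True)
output_reg = model(inputs_reg)
num_classes = output_reg.size(1)
cols = torch.arange(num_classes, device=output_reg.device)

jacobian_list = []
for i in range(num_classes):
    # New tensor every iteration: ones in column i, zeros elsewhere (out-of-place).
    grad_output = torch.where(cols[None, :] == i,
                              torch.ones_like(output_reg),
                              torch.zeros_like(output_reg))
    jacobian_list.append(torch.autograd.grad(outputs=output_reg,
                                             inputs=inputs_reg,
                                             grad_outputs=grad_output,
                                             retain_graph=True,
                                             create_graph=True)[0])

jacobian = torch.stack(jacobian_list, dim=0)  # shape: (num_classes, batch, features)
loss3 = jacobian.norm()
loss3.backward()                              # no in-place error this time
```

Writing grad_output[:, i] = 1 into a tensor is also fine as long as that tensor is newly created (for example with torch.zeros_like) inside the same iteration, because the write happens before autograd saves it. The failure in the question comes from reusing one buffer across iterations.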