python, pytorch, autograd, stochastic-process

Individual gradients with torch.autograd.grad without summing over the second variable (time)


I have a sampled path of a stochastic process starting from an initial point:


import torch
import torch.nn as nn
import torchsde


class SDE_ou_1d(nn.Module):
    def __init__(self):
        super().__init__()
        self.sde_type = "ito"
        self.noise_type = "diagonal"

    def f(self, t, y):  # drift
        return -y

    def g(self, t, y):  # diffusion (volatility)
        return torch.ones_like(y)

t_vec = torch.linspace(0, 1, 100)  # time array
mySDE = SDE_ou_1d()

x0 = torch.zeros(1, 1, requires_grad=True).to(t_vec)
X_t = torchsde.sdeint(mySDE, x0, t_vec, method='euler')  # shape (len(t_vec), batch_size, state_size)

and I would like to compute the gradient with respect to the initial condition using torch.autograd.grad(), getting an output with the same shape as X_t, i.e. one gradient per time point. This gives the sensitivity of the path to the initial condition at every time point:


X_grad = torch.autograd.grad(outputs=X_t, inputs=x0,
                             grad_outputs=torch.ones_like(X_t),
                             create_graph=False, retain_graph=True,
                             only_inputs=True, allow_unused=True)[0]

The issue is that this gradient is summed over all values of t: the call above returns a single tensor with the shape of x0 rather than one gradient per time point.
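For example (using the X_t and x0 defined above; the time index 10 is arbitrary), a one-hot grad_outputs isolates a single time step, which is exactly what each iteration of the loop below does:

e_10 = torch.zeros_like(X_t)
e_10[10] = 1.0  # select time step 10 only
grad_t10 = torch.autograd.grad(outputs=X_t, inputs=x0,
                               grad_outputs=e_10,
                               retain_graph=True)[0]
# grad_t10 is dX_t[10]/dx0; summing these over all time indices reproduces the
# single gradient returned above with grad_outputs=torch.ones_like(X_t)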

I can do this with a for loop, but it is very slow and not practical:

X_grad_loop = torch.zeros_like(X_t)

for i in range(X_t.shape[0]):  # Loop over the first dimension of X_t, which is time
    grad_i = torch.autograd.grad(outputs=X_t[i, ...], inputs=x0,
                                 grad_outputs=torch.ones_like(X_t[i, ...]),
                                 create_graph=False, retain_graph=True,
                                 only_inputs=True, allow_unused=True)[0]
    X_grad_loop[i, ...] = grad_i

Is there a way to compute this gradient with torch.autograd.grad() and no loop? Thanks.


Solution

  • I finally found a way to make it work using torch.func.jacrev.

    import torch
    import torchsde

    # mySDE and t_vec are defined as in the question; Dx is the state
    # dimension, N_sample the number of sample paths, and bm a fixed
    # Brownian motion so that every call to sample_SDE sees the same noise.
    def sample_SDE(t_vec, N_sample, x0):
        x0_per_sample = x0.unsqueeze(0).expand(N_sample, -1)
        return torchsde.sdeint(mySDE, x0_per_sample, t_vec, method='euler', bm=bm)

    # Here we remove the N_sample dimension from x0, otherwise jacrev will
    # treat it as another variable to differentiate through
    x0 = torch.zeros(Dx, requires_grad=True).to(t_vec)

    # then we differentiate
    Jacob_autograd = torch.func.jacrev(sample_SDE, argnums=2)(t_vec, N_sample, x0)
    

    In the last line, argnums=2 specifies that x0 (the third argument of sample_SDE) is the variable to differentiate with respect to.
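
    For completeness, here is a minimal sketch of the pieces the snippet above assumes; the concrete values of Dx and N_sample and the BrownianInterval construction are illustrative choices, not part of the original answer:

    Dx, N_sample = 1, 8                  # state dimension and number of sample paths (assumed)
    t_vec = torch.linspace(0, 1, 100)
    mySDE = SDE_ou_1d()                  # the SDE module from the question

    # A fixed Brownian motion; reusing it makes sample_SDE deterministic, so the
    # Jacobian is taken along one and the same set of noise paths.
    bm = torchsde.BrownianInterval(t0=0.0, t1=1.0, size=(N_sample, Dx))

    With these definitions, Jacob_autograd has shape (len(t_vec), N_sample, Dx, Dx). For the 1-D SDE above, Jacob_autograd[:, :, 0, 0] gives the gradient of every sampled path with respect to the initial condition at every time point, with no Python loop over t.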