I have a sampled path of a stochastic process starting from an initial point:
class SDE_ou_1d(nn.Module):
    """Ornstein-Uhlenbeck-style SDE with drift ``-y`` and unit diffusion.

    The attributes ``sde_type`` and ``noise_type`` and the methods ``f``
    (drift) and ``g`` (diffusion) are what the SDE solver is handed when
    this object is passed to ``torchsde.sdeint``.
    """

    def __init__(self):
        # nn.Module keeps its internal registries (parameters, buffers,
        # sub-modules) in state created by its own __init__, so it has to
        # run before the instance is used as a module.
        super().__init__()
        # Ito calculus; the original value "into" was a typo and is not a
        # recognised SDE type.
        self.sde_type = "ito"
        self.noise_type = "diagonal"

    def f(self, t, y):
        """Drift term: pulls the state back towards zero."""
        return -y

    def g(self, t, y):
        """Diffusion (volatility) term: constant 1 with the shape of ``y``."""
        return torch.ones_like(y)
# Time grid: 100 evenly spaced points on [0, 1].
t_vec = torch.linspace(0, 1, 100)

# The SDE whose path is sampled.
mySDE = SDE_ou_1d()

# Initial condition of shape (1, 1), tracked by autograd so the path can be
# differentiated with respect to it; .to(t_vec) matches t_vec's dtype/device.
x0 = torch.zeros(1, 1, requires_grad=True).to(t_vec)

# Sample one path on the time grid with the Euler scheme.
X_t = torchsde.sdeint(mySDE, x0, t_vec, method="euler")
I would like to compute the gradient of the path with respect to the initial condition using torch.autograd.grad(), and get an output with the same shape as X_t, i.e. 100x1.
This would give the change in the path at every time point with respect to the initial condition:
# X_t holds one value per time point, so it is not a scalar, and
# torch.autograd.grad cannot create the upstream gradient implicitly.
# Passing ones as grad_outputs makes the call well defined: the result is
# the sum over all entries of X_t of d X_t / d x0, i.e. a single gradient
# with the shape of x0, not one gradient per time point.
X_grad = torch.autograd.grad(
    outputs=X_t,
    inputs=x0,
    grad_outputs=torch.ones_like(X_t),
    create_graph=False,
    retain_graph=True,
    only_inputs=True,
    allow_unused=True,
)[0]
The issue is that the resulting gradient is summed over all values of t, instead of being returned separately for each time point.
I can get the gradient at each time point with a for loop, but it is very slow and not practical:
# Per-time-point gradient of the path with respect to x0, filled in one
# time index at a time (one backward pass per index).
X_grad_loop = torch.zeros_like(X_t)
n_steps = X_t.size(0)  # first dimension of X_t is time
for step in range(n_steps):
    # retain_graph=True keeps the graph alive for the next iteration.
    (grad_step,) = torch.autograd.grad(
        outputs=X_t[step],
        inputs=x0,
        create_graph=False,
        retain_graph=True,
        only_inputs=True,
        allow_unused=True,
    )
    X_grad_loop[step] = grad_step
Is there a way to compute this gradient with torch.autograd.grad() and no loop?
I finally found a way to make it work using jacrev:
from functorch import jacrev
def sample_SDE(t_vec, N_sample, x0):
    """Sample ``N_sample`` paths that all start from the same point ``x0``.

    Args:
        t_vec: time grid passed to the solver.
        N_sample: number of paths; becomes the leading (batch) axis of the
            initial state.
        x0: initial point without a batch axis (the ``expand(N_sample, -1)``
            call below implies it is 1-D).

    Returns:
        Whatever ``torchsde.sdeint`` returns for the batched initial state.

    Note:
        ``mySDE`` and ``bm`` are not parameters; they are read from the
        enclosing scope.
    """
    # Add a leading axis of size 1, then broadcast it to N_sample rows.
    batched_start = x0.unsqueeze(0).expand(N_sample, -1)
    paths = torchsde.sdeint(mySDE, batched_start, t_vec, method="euler", bm=bm)
    return paths
# x0 is created without the N_sample axis (shape (Dx,)); per the author,
# jacrev would otherwise treat that axis as another variable to
# differentiate through. sample_SDE adds the axis back internally.
x0 = torch.zeros(Dx, requires_grad=True).to(t_vec)

# Differentiate sample_SDE with respect to its argument at index 2 (x0).
Jacob_autograd = torch.func.jacrev(sample_SDE, argnums=2)(
    t_vec, N_sample, x0
)
In the last line, `argnums=2` specifies that x0 (the argument at index 2 of sample_SDE) is the variable to differentiate with respect to.