Does Flux.jl have an equivalent to PyTorch's `rsample`, which automatically applies the reparameterization trick so that samples become differentiable and stochastic/policy gradients can flow through them?
After researching and asking about it on the Julia Discourse, it seems there is no `rsample` equivalent in Julia that simplifies the reparameterization trick. However, there appears to be some unreleased work in this direction that may be shared soon.
https://discourse.julialang.org/t/reparametrization-trick-in-flux-jl/100489
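In the meantime, the trick is easy to write by hand in any AD framework: draw noise `eps ~ N(0, 1)` outside the computation graph and form `z = mu + sigma * eps`, which is deterministic and differentiable in `mu` and `sigma`. As a framework-agnostic sanity check (a sketch in plain Python, no autograd; the function name and setup are illustrative, not from any library), the pathwise estimator below recovers the true gradient `d/dmu E[z^2] = 2*mu`:

```python
import random

def pathwise_grad_mu(mu, sigma, n=200_000, seed=0):
    """Monte Carlo estimate of d/dmu E[z^2] for z ~ N(mu, sigma^2),
    using the reparameterization z = mu + sigma * eps, eps ~ N(0, 1)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)   # noise sampled outside the "graph"
        z = mu + sigma * eps        # differentiable in mu (dz/dmu = 1)
        total += 2.0 * z            # per-sample gradient: d(z^2)/dmu = 2z
    return total / n

# True value: d/dmu E[z^2] = d/dmu (mu^2 + sigma^2) = 2*mu = 3.0 here.
print(pathwise_grad_mu(1.5, 0.5))
```

In Flux.jl the same pattern works directly with Zygote: sample `eps` with `randn` inside the differentiated function and compute `mu .+ sigma .* eps`; gradients then flow to `mu` and `sigma` exactly as with `rsample`.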