I am trying to rewrite a TensorFlow script in PyTorch, and I cannot find the PyTorch equivalent of the following statement from that script:
```python
import tensorflow_probability as tfp
tfd = tfp.distributions

a_distribution = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0.0, scale=1.0),
    bijector=tfp.bijectors.Chain([
        tfp.bijectors.AffineScalar(shift=self._means,
                                   scale=self._mags),
        tfp.bijectors.Tanh(),
        tfp.bijectors.AffineScalar(shift=mean, scale=std),
    ]),
    event_shape=[mean.shape[-1]],
    batch_shape=[mean.shape[0]])
```
In particular, I am having trouble replacing the tfp.bijectors.Chain component.
I wrote the following lines in PyTorch, but I am not sure whether they are equivalent to the TensorFlow code above, or whether I can specify the batch_shape somewhere:
```python
base_distribution = torch.normal(0.0, 1.0)
transforms = torch.distributions.transforms.ComposeTransform([
    torch.distributions.transforms.AffineTransform(
        loc=self._action_means, scale=self._action_mag, event_dim=mean.shape[-1]),
    torch.nn.Tanh(),
    torch.distributions.transforms.AffineTransform(
        loc=mean, scale=std, event_dim=mean.shape[-1]),
])
a_distribution = torch.distributions.transformed_distribution.TransformedDistribution(
    base_distribution, transforms)
```
Any solution?
In PyTorch, the base class Distribution accepts both a batch_shape and an event_shape parameter. However, the subclass TransformedDistribution does not take these parameters, because they are inferred from the base distribution passed at initialization (together with the event_dim of the transforms).
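A small sketch of how the shapes come out in recent PyTorch versions: the batch shape follows the shapes of loc and scale, and a transform's event_dim decides how many rightmost dimensions are folded into the event shape. The sizes 4 and 3 are arbitrary illustration values.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform

# Base distribution: its batch shape comes from the shapes of loc and scale.
base = Normal(torch.zeros(4, 3), torch.ones(4, 3))
print(base.batch_shape, base.event_shape)  # torch.Size([4, 3]) torch.Size([])

# Elementwise transform (default event_dim=0): shapes pass through unchanged.
elementwise = TransformedDistribution(base, [AffineTransform(loc=1.0, scale=2.0)])
print(elementwise.batch_shape, elementwise.event_shape)  # torch.Size([4, 3]) torch.Size([])

# event_dim=1 treats the rightmost dimension as one event,
# which mirrors TFP's batch_shape=[4], event_shape=[3].
per_vector = TransformedDistribution(base, [AffineTransform(loc=1.0, scale=2.0, event_dim=1)])
print(per_vector.batch_shape, per_vector.event_shape)  # torch.Size([4]) torch.Size([3])

# log_prob now returns one value per batch element.
log_p = per_vector.log_prob(per_vector.sample())
print(log_p.shape)  # torch.Size([4])
```

So the batch shape is "specified" by the shape of the tensors you pass to Normal, not by an argument of TransformedDistribution.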
You already found AffineTransform and ComposeTransform. Keep in mind that you must stick with classes from the torch.distributions module.
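One detail worth checking when porting tfp.bijectors.Chain: Chain([b1, b2]).forward(x) computes b1(b2(x)), i.e. it applies the list from last to first, whereas ComposeTransform applies its parts from first to last. The sketch below, with arbitrary shift/scale values, confirms the PyTorch order; a faithful port therefore lists the transforms in the reverse of the TFP order.

```python
import torch
from torch.distributions import transforms as tT

x = torch.tensor([0.0, 0.5, -1.0])

shift_then_squash = tT.ComposeTransform([
    tT.AffineTransform(loc=1.0, scale=2.0),  # applied first: 1 + 2 * x
    tT.TanhTransform(),                       # applied second: tanh(1 + 2 * x)
])

# ComposeTransform applies its parts left to right ...
assert torch.allclose(shift_then_squash(x), torch.tanh(1.0 + 2.0 * x))
# ... so it does NOT compute 1 + 2 * tanh(x), which is what
# tfb.Chain([AffineScalar(shift=1, scale=2), Tanh()]) would compute.
assert not torch.allclose(shift_then_squash(x), 1.0 + 2.0 * torch.tanh(x))
```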
This applies to torch.normal, which draws a sample rather than building a distribution object, and should be replaced with Normal. With this class, the batch shape is inferred from the shapes of the provided loc and scale tensors.
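A quick sketch of the difference, assuming a (2, 3) mean purely for illustration: torch.normal returns a plain tensor of samples, while Normal returns a Distribution whose batch shape matches its parameters. Passing Python floats gives a scalar distribution, which is why the parameters should be tensors shaped like mean.

```python
import torch
from torch.distributions import Normal

mean = torch.rand(2, 3)

# torch.normal draws a sample: the result is a plain Tensor, not a distribution.
sample = torch.normal(torch.zeros_like(mean), torch.ones_like(mean))
print(type(sample))  # <class 'torch.Tensor'>

# Normal is a Distribution; batch_shape is inferred from loc and scale.
base = Normal(loc=torch.zeros_like(mean), scale=torch.ones_like(mean))
print(base.batch_shape)  # torch.Size([2, 3])
print(base.log_prob(sample).shape)  # torch.Size([2, 3])

# Scalar parameters give an empty batch shape.
scalar = Normal(0.0, 1.0)
print(scalar.batch_shape)  # torch.Size([])
```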
It also applies to nn.Tanh, which is a neural-network module rather than a Transform and should be replaced with TanhTransform.
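A sketch of why the swap matters: both objects compute the same forward values, but only TanhTransform is a Transform with an inverse and a log-determinant, which TransformedDistribution needs in log_prob. The expected log|dy/dx| is log(1 - tanh(x)^2).

```python
import torch
from torch import nn
from torch.distributions.transforms import Transform, TanhTransform

tanh_module = nn.Tanh()
tanh_transform = TanhTransform()

# nn.Tanh is a layer, not a Transform, so ComposeTransform cannot use it.
print(isinstance(tanh_module, Transform))     # False
print(isinstance(tanh_transform, Transform))  # True

x = torch.linspace(-2.0, 2.0, 5)
y = tanh_transform(x)

# Same forward values as nn.Tanh ...
assert torch.allclose(y, tanh_module(x))
# ... plus an inverse and the log|dy/dx| term used by log_prob.
assert torch.allclose(tanh_transform.inv(y), x, atol=1e-5)
assert torch.allclose(tanh_transform.log_abs_det_jacobian(x, y),
                      torch.log1p(-torch.tanh(x) ** 2), atol=1e-5)
```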
Here is a minimal example using your transformation pipeline:
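```python
import torch
from torch.distributions.normal import Normal
from torch.distributions import transforms as tT
from torch.distributions.transformed_distribution import TransformedDistribution

mean = torch.rand(2, 2)
std = 1
_action_means, _action_mag = 0, 1
# event_dim is the *number* of event dimensions, not their size:
# 1 reproduces TFP's event_shape=[mean.shape[-1]], batch_shape=[mean.shape[0]].
event_dim = 1

a_distribution = TransformedDistribution(
    base_distribution=Normal(loc=torch.full_like(mean, 0),
                             scale=torch.full_like(mean, 1)),
    # ComposeTransform applies its parts first to last, whereas tfb.Chain
    # applies them last to first, so the order is reversed w.r.t. the TFP list.
    transforms=tT.ComposeTransform([
        tT.AffineTransform(loc=mean, scale=std, event_dim=event_dim),
        tT.TanhTransform(),
        tT.AffineTransform(loc=_action_means, scale=_action_mag, event_dim=event_dim)]))
```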
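As a sanity check, here is a self-contained sketch of the same pipeline (TFP chain order reversed, event_dim=1) being used. The values are hypothetical: a (2, 2) mean, std of 0.5, and action_means/action_mags standing in for self._means/self._mags. The shapes should match TFP's batch_shape=[2], event_shape=[2], rsample gives reparameterized samples squashed into the action range, and log_prob returns one value per batch element.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions import transforms as tT

torch.manual_seed(0)
mean = torch.rand(2, 2)               # batch of 2 actions, 2 dimensions each
std = torch.full_like(mean, 0.5)
action_means, action_mags = 0.0, 1.0  # hypothetical stand-ins for self._means / self._mags

a_distribution = TransformedDistribution(
    Normal(torch.zeros_like(mean), torch.ones_like(mean)),
    [
        tT.AffineTransform(loc=mean, scale=std, event_dim=1),  # z = mean + std * x
        tT.TanhTransform(),                                    # squash into (-1, 1)
        tT.AffineTransform(loc=action_means, scale=action_mags, event_dim=1),
    ],
)
print(a_distribution.batch_shape, a_distribution.event_shape)  # torch.Size([2]) torch.Size([2])

actions = a_distribution.rsample()        # reparameterized sample, shape (2, 2)
log_p = a_distribution.log_prob(actions)  # one log-density per batch element, shape (2,)
```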