Tags: tensorflow, deep-learning, pytorch, distribution, tensorflow-probability

Convert a tensorflow script to pytorch (TransformedDistribution)


I am trying to rewrite a tensorflow script in pytorch. I have a problem finding the equivalent part in torch for the following line from this script:

import tensorflow_probability as tfp
tfd = tfp.distributions
a_distribution = tfd.TransformedDistribution(
        distribution=tfd.Normal(loc=0.0, scale=1.0),
        bijector=tfp.bijectors.Chain([
            tfp.bijectors.AffineScalar(shift=self._means,
                                       scale=self._mags),
            tfp.bijectors.Tanh(),
            tfp.bijectors.AffineScalar(shift=mean, scale=std),
        ]),
        event_shape=[mean.shape[-1]],
        batch_shape=[mean.shape[0]])

In particular, I am having trouble replacing the tfp.bijectors.Chain component. I wrote the following lines in torch, but I am wondering whether they are equivalent to the above tensorflow code, and whether I can specify the batch_shape somewhere?

base_distribution = torch.normal(0.0, 1.0)
transforms = torch.distributions.transforms.ComposeTransform([torch.distributions.transforms.AffineTransform(loc=self._action_means, scale=self._action_mag, event_dim=mean.shape[-1]), torch.nn.Tanh(),torch.distributions.transforms.AffineTransform(loc=mean, scale=std, event_dim=mean.shape[-1])])
a_distribution = torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms)
 

Any solution?


Solution

  • In PyTorch, the base class Distribution takes both a batch_shape and an event_shape parameter. The subclass TransformedDistribution takes neither (see its source code), because both are inferred from the base distribution passed at initialization.

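    You already found out about AffineTransform and ComposeTransform. Keep in mind you must stick with classes from torch.distributions: torch.normal is a sampling function rather than a Distribution, and torch.nn.Tanh is a module rather than a Transform, so use Normal and TanhTransform instead.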
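To illustrate how the shapes are inferred, here is a minimal sketch. The tensor sizes (a batch of 4 with event size 3) are assumptions chosen for illustration, and wrapping the base Normal in Independent is one way, not the only way, to mark the last dimension as the event dimension:

```python
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

# Hypothetical sizes: a batch of 4 events, each of size 3.
mean = torch.zeros(4, 3)

# Normal alone has batch_shape (4, 3) and an empty event_shape.
# Independent(..., 1) reinterprets the last batch dimension as the event dimension.
base = Independent(Normal(loc=mean, scale=torch.ones_like(mean)), 1)

# TransformedDistribution takes no shape arguments; it reads them from `base`.
dist = TransformedDistribution(base, [TanhTransform()])

print(dist.batch_shape)  # torch.Size([4])
print(dist.event_shape)  # torch.Size([3])
```

With this setup, log_prob returns one value per batch element, because the per-coordinate terms are summed over the event dimension.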


    Here is a minimal example using your transformation pipeline:

    Imports:
    import torch
    from torch.distributions.normal import Normal
    from torch.distributions import transforms as tT
    from torch.distributions.transformed_distribution import TransformedDistribution
    
    Parameters:
    mean = torch.rand(2,2)
    std = 1
    _action_means, _action_mag = 0, 1
    event_dim = 1  # number of event dimensions (1 for vectors), not the size mean.shape[-1]
    
    Distribution definition:
    a_distribution = TransformedDistribution(
        base_distribution=Normal(loc=torch.full_like(mean, 0), 
                                 scale=torch.full_like(mean, 1)), 
        transforms=tT.ComposeTransform([
            tT.AffineTransform(loc=_action_means, scale=_action_mag, event_dim=event_dim), 
            tT.TanhTransform(),
            tT.AffineTransform(loc=mean, scale=std, event_dim=event_dim)]))
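One point worth checking against the original script: tfp.bijectors.Chain applies its bijectors from last to first, whereas ComposeTransform (and a plain list of transforms) applies them from first to last. If the goal is to reproduce the TensorFlow pipeline exactly, the list order would need to be reversed. The sketch below assumes that goal; the names action_means and action_mag are stand-ins for self._means and self._mags, and the parameter values are arbitrary:

```python
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions import transforms as tT

# Arbitrary illustrative parameters (assumptions, not values from the original script).
mean = torch.rand(2, 2)
std = torch.full_like(mean, 0.5)
action_means, action_mag = 0.0, 1.0  # stand-ins for self._means / self._mags

# Base distribution: batch_shape (2,), event_shape (2,).
base = Independent(Normal(torch.zeros_like(mean), torch.ones_like(mean)), 1)

# TFP's Chain([A, Tanh, B]) computes A(tanh(B(x))), so B comes first here.
a_distribution = TransformedDistribution(base, [
    tT.AffineTransform(loc=mean, scale=std),                 # applied first
    tT.TanhTransform(),                                      # squashes to (-1, 1)
    tT.AffineTransform(loc=action_means, scale=action_mag),  # applied last
])

sample = a_distribution.sample()
log_prob = a_distribution.log_prob(sample)
print(a_distribution.batch_shape, a_distribution.event_shape)
print(sample.shape, log_prob.shape)
```

Here the event dimension comes from the Independent wrapper, so the transforms can keep their default event_dim of 0.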