python, tensorflow, keras, tensorflow-probability

How to use values from previous Keras layer in convert_to_tensor_fn for TensorFlow Probability DistributionLambda


I have a Keras/TensorFlow Probability model where I would like to use values from the previous layer in the convert_to_tensor_fn parameter of the following DistributionLambda layer. Ideally, I would write something like this:

from functools import partial
import tensorflow as tf
from tensorflow.keras import layers, Model
import tensorflow_probability as tfp
from typing import Union
tfd = tfp.distributions

zero_buffer = 1e-5


def quantile(s: tfd.Distribution, q: Union[tf.Tensor, float]) -> Union[tf.Tensor, float]:
    return s.quantile(q)


# 4 records (1st value represents CDF value, 
#            2nd represents location, 
#            3rd represents scale)
sample_input = tf.constant([[0.25, 0.0, 1.0], 
                            [0.5, 1.0, 0.5], 
                            [0.75, -1.0, 2.0], 
                            [0.95, 3.0, 2.5]], dtype=tf.float32)

# Build toy model for demonstration
input_layer = layers.Input(shape=(3,))
dist = tfp.layers.DistributionLambda(
    make_distribution_fn=lambda t: tfd.Normal(loc=t[..., 1],
                                              scale=zero_buffer + tf.nn.softplus(t[..., 2])),
    convert_to_tensor_fn=lambda t, s: partial(quantile, q=t[..., 0])(s)
)(input_layer)
model = Model(input_layer, dist)

However, according to the documentation, convert_to_tensor_fn must be a callable that takes only a tfd.Distribution as input, so the two-argument convert_to_tensor_fn=lambda t, s: ... above doesn't work.

How can I access data from the previous layer inside convert_to_tensor_fn? I'm assuming there's a clever way to create a partial function, or something similar, to get this to work.
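For contrast, a single-argument callable is accepted, but then the CDF value has to be hardcoded rather than read from the layer input:

# Legal, because the callable takes only the distribution --
# but q is fixed instead of coming from t[..., 0]
dist = tfp.layers.DistributionLambda(
    make_distribution_fn=lambda t: tfd.Normal(loc=t[..., 1],
                                              scale=zero_buffer + tf.nn.softplus(t[..., 2])),
    convert_to_tensor_fn=lambda s: s.quantile(0.5)
)(input_layer)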

Outside of the Keras model framework, this is fairly easy to do using code similar to the example below:

# input data in Tensor Constant form
cdf_data = tf.constant([0.25, 0.5, 0.75, 0.95], dtype=tf.float32)
norm_mu = tf.constant([0.0, 1.0, -1.0, 3.0], dtype=tf.float32)
norm_scale = tf.constant([1.0, 0.5, 2.0, 2.5], dtype=tf.float32)

quant = partial(quantile, q=cdf_data)
norm = tfd.Normal(loc=norm_mu, scale=norm_scale)
quant(norm)

Output:

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([-0.6744898,  1.       ,  0.3489796,  7.112134 ], dtype=float32)>

Solution

  • I found a solution to this problem on my own, and decided to post it here.

    You can create a wrapper class around tfd.Normal that takes the CDF values as an extra argument, and then override a couple of methods to do what you want. In particular, you need to override the _sample_n method so that it returns the quantile at the stored CDF values instead of a random draw from the distribution. The class looks something like this:

    import tensorflow as tf
    import tensorflow_probability as tfp
    from tensorflow_probability.python.internal import dtype_util, tensor_util
    from tensorflow_probability.python.internal import prefer_static as ps
    tfd = tfp.distributions
    
    
    class NormalWrapper(tfp.distributions.Normal):
        def __init__(self,
                     loc,
                     scale,
                     cdf_vals,
                     validate_args=False,
                     allow_nan_stats=True,
                     name='NormalCDF'):
            parameters = dict(locals())
            with tf.name_scope(name) as name:
                dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
                self._cdf_vals = tensor_util.convert_nonref_to_tensor(
                    cdf_vals, dtype=dtype, name='cdf_vals')
            super(NormalWrapper, self).__init__(loc=loc,
                                                scale=scale,
                                                validate_args=validate_args,
                                                allow_nan_stats=allow_nan_stats,
                                                name=name)
            self._parameters = parameters
    
        @classmethod
        def _parameter_properties(cls, dtype=tf.float32, num_classes=None):
            return dict(
                loc=tfp.util.ParameterProperties(),
                scale=tfp.util.ParameterProperties(
                    # scale must stay positive; use a Softplus *bijector* here
                    # (tf.nn.softplus is a plain op and takes no `low` argument)
                    default_constraining_bijector_fn=(
                        lambda: tfp.bijectors.Softplus(low=dtype_util.eps(dtype)))),
                cdf_vals=tfp.util.ParameterProperties(),
            )
    
        @property
        def cdf_vals(self):
            return self._cdf_vals
    
        def _sample_n(self, n, seed=None):
            # "Sampling" is deterministic: return the quantile at the stored
            # CDF values instead of a random draw.
            loc = tf.convert_to_tensor(self.loc)
            scale = tf.convert_to_tensor(self.scale)
            cdf_vals = tf.convert_to_tensor(self.cdf_vals)
            shape = ps.concat([[n], self._batch_shape_tensor(loc=loc, scale=scale, cdf_vals=cdf_vals)], axis=0)
            # broadcast_to (rather than reshape) also handles n > 1, where
            # every "draw" is the same quantile value
            return tf.broadcast_to(self.quantile(cdf_vals), shape)
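
    As a quick sanity check, instantiating the wrapper directly with the toy values from the question shows that sample() now returns the quantiles deterministically:

    cdf_data = tf.constant([0.25, 0.5, 0.75, 0.95], dtype=tf.float32)
    norm_mu = tf.constant([0.0, 1.0, -1.0, 3.0], dtype=tf.float32)
    norm_scale = tf.constant([1.0, 0.5, 2.0, 2.5], dtype=tf.float32)

    norm = NormalWrapper(loc=norm_mu, scale=norm_scale, cdf_vals=cdf_data)
    norm.sample()
    # <tf.Tensor: shape=(4,), dtype=float32, numpy=array([-0.6744898,  1. ,  0.3489796,  7.112134 ], dtype=float32)>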
    
    

    Once you have that class, you can create your DistributionLambda layer like this:

    dist = tfp.layers.DistributionLambda(
        make_distribution_fn=lambda t: NormalWrapper(loc=t[..., 1],
                                                     scale=zero_buffer + tf.nn.softplus(t[..., 2]),
                                                     cdf_vals=t[..., 0]),
    )(input_layer)
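
    Since DistributionLambda's default convert_to_tensor_fn is tfd.Distribution.sample, and _sample_n now returns quantiles, the tensor flowing out of this layer is the quantile at the CDF value in each record's first column. A minimal end-to-end sketch with the question's sample_input (note that the softplus transform means the effective scales differ from the raw third column, so the numbers won't match the standalone example above):

    model = Model(input_layer, dist)
    quantiles = model(sample_input)  # deterministic: identical values on every call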