Consider the definition of the `observe` statement from probabilistic programming, as given in [1]:
The observe statement blocks runs which do not satisfy the boolean expression E and does not permit those executions to happen.
Now, consider the following theoretical program:
```
def f():
    x ~ Normal(0, 1)
    observe(x > 0)  # only allow samples x > 0
    return x
```
This program should return values from the Normal(0, 1) distribution truncated to x > 0.
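Read operationally, the semantics quoted from [1] amount to rejection sampling: re-run the program and discard every run in which E is false. A framework-free sketch in plain Python, using the standard library's `random.gauss` in place of `Normal(0, 1)`; the names `observe`, `ObserveFailed` and `run_with_observe` are hypothetical illustrations, not part of any library:

```python
import random


class ObserveFailed(Exception):
    """Raised by observe() to abort a run whose condition is false."""


def observe(condition: bool) -> None:
    # Block the current run if the observed boolean expression E is false.
    if not condition:
        raise ObserveFailed


def f(rng: random.Random) -> float:
    x = rng.gauss(0.0, 1.0)  # x ~ Normal(0, 1)
    observe(x > 0)           # only allow runs with x > 0
    return x


def run_with_observe(program, rng: random.Random, max_tries: int = 10_000):
    """Re-execute `program` until one run passes every observe()."""
    for _ in range(max_tries):
        try:
            return program(rng)
        except ObserveFailed:
            continue  # rejected run: discard it and try again
    raise RuntimeError("no accepted run within max_tries")


rng = random.Random(0)
samples = [run_with_observe(f, rng) for _ in range(1000)]
print(min(samples) > 0)  # every accepted sample satisfies x > 0
```

Because each accepted run is just a Normal(0, 1) draw that happened to be positive, the accepted samples follow the truncated (half-)normal distribution.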
Therefore, my question is: how can `observe` be achieved in TensorFlow Probability, or what is its equivalent? Note that the argument of `observe` should be any (symbolic) boolean expression E (e.g. `lambda x: x > 0`).
NOTE: Sure, for the program above one can use the `HalfNormal` distribution, but I am using it as a practical example of `observe`.
[1] Gordon, Andrew D., et al. “Probabilistic programming.” Proceedings of the Future of Software Engineering (FOSE 2014), 2014, pp. 167-181.
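The only way to achieve this in general is a rejection sampler, which is expensive when E is rarely satisfied. And the resulting distribution does not have a tractable density: normalizing it requires the probability that E holds, which is generally unknown. In general, TFP requires all our distributions to have a tractable density (i.e. `dist.prob(x)` must be computable). For this example we do have an autodiff-friendly `TruncatedNormal`, or, as you note, `HalfNormal`.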
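To see why the density is the sticking point, write out what `observe(E)` does to the density of x: it restricts the support to the runs where E holds and renormalizes by the probability of E.

$$
p(x \mid E) = \frac{p(x)\,\mathbf{1}[E(x)]}{P(E)}, \qquad P(E) = \int p(x)\,\mathbf{1}[E(x)]\,dx .
$$

For the program in the question, $p = \mathcal{N}(0, 1)$ with density $\varphi$ and $E(x) = (x > 0)$, so $P(E) = 1/2$ and

$$
p(x \mid x > 0) = 2\,\varphi(x)\,\mathbf{1}[x > 0],
$$

which is exactly the `HalfNormal(scale=1)` density. For an arbitrary predicate E, $P(E)$ has no closed form in general, so `prob(x)` cannot be evaluated exactly even though rejection sampling still works. The number of draws needed per accepted sample is geometric with mean $1/P(E)$, which is where the cost comes from.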
If you want to support an arbitrary condition, a rejection-sampling distribution could be as simple as:
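```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

class Rejection(tfd.Distribution):
    # Sampling only: there is no _log_prob, since P(condition) is unknown in
    # general. `condition` must return an elementwise boolean tensor shaped
    # like the samples it is given.
    def __init__(self, underlying, condition, name=None):
        parameters = dict(locals())
        self._u = underlying
        self._c = condition
        super().__init__(dtype=underlying.dtype,
                         name=name or f'rejection_{underlying.name}',
                         reparameterization_type=tfd.NOT_REPARAMETERIZED,
                         validate_args=underlying.validate_args,
                         allow_nan_stats=underlying.allow_nan_stats,
                         parameters=parameters)
    def _batch_shape(self):
        return self._u.batch_shape
    def _batch_shape_tensor(self):
        return self._u.batch_shape_tensor()
    def _event_shape(self):
        return self._u.event_shape
    def _event_shape_tensor(self):
        return self._u.event_shape_tensor()
    def _sample_n(self, n, seed=None):
        # Split the seed on every iteration: reusing one stateless seed would
        # redraw the same rejected values forever.
        init_seed, loop_seed = tfp.random.split_seed(tfp.random.sanitize_seed(seed))
        def cond(samples, seed):
            # Keep looping while any element still violates the condition.
            return tf.logical_not(tf.reduce_all(self._c(samples)))
        def body(samples, seed):
            draw_seed, next_seed = tfp.random.split_seed(seed)
            fresh = self._u.sample(n, seed=draw_seed)
            # Keep accepted elements; replace rejected ones with fresh draws.
            return tf.where(self._c(samples), samples, fresh), next_seed
        samples = self._u.sample(n, seed=init_seed)
        return tf.while_loop(cond, body, (samples, loop_seed))[0]

d = Rejection(tfd.Normal(0., 1.), lambda x: x > -.3)
s = d.sample(100).numpy()
print(s.min())  # every sample satisfies x > -0.3
```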
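The class above deliberately implements only sampling. If an approximate density is acceptable, one option (my own suggestion, not something TFP provides for this class) is to estimate $P(E)$ by Monte Carlo and use $\log p(x) - \log P(E)$ for any x that satisfies E. A standard-library sketch for the example condition x > -0.3, where `statistics.NormalDist` stands in for `tfd.Normal` and the exact answer $\Phi(0.3)$ is available for comparison; `estimate_acceptance` and `approx_log_prob` are hypothetical helper names:

```python
import math
import random
from statistics import NormalDist

base = NormalDist(0.0, 1.0)  # underlying distribution p(x)


def condition(x: float) -> bool:
    # The observed predicate E from the example above.
    return x > -0.3


def estimate_acceptance(n: int, rng: random.Random) -> float:
    """Monte Carlo estimate of P(E): fraction of base draws satisfying E."""
    hits = sum(condition(rng.gauss(0.0, 1.0)) for _ in range(n))
    return hits / n


def approx_log_prob(x: float, p_e: float) -> float:
    """log p(x | E) = log p(x) - log P(E) inside the support, -inf outside."""
    if not condition(x):
        return -math.inf
    return math.log(base.pdf(x)) - math.log(p_e)


rng = random.Random(42)
p_e_hat = estimate_acceptance(200_000, rng)
p_e_exact = base.cdf(0.3)  # P(X > -0.3) = Phi(0.3) by symmetry
print(f"estimated P(E) = {p_e_hat:.4f}, exact = {p_e_exact:.4f}")
print(approx_log_prob(0.0, p_e_hat))
```

The estimate's standard error shrinks like $1/\sqrt{n}$, so the resulting log density is only as accurate as the number of draws spent on $P(E)$; it is also not differentiable with respect to the underlying distribution's parameters in this form.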