numpy, tensorflow, tensorflow-probability

TensorFlow Probability (tfp) equivalent of np.quantile()


I am trying to find a TensorFlow equivalent of np.quantile(). I have found tfp.stats.quantiles() (tfp stands for TensorFlow Probability), but its interface is somewhat different from that of np.quantile().

Consider the following example:

import tensorflow_probability as tfp
import tensorflow as tf 
import numpy as np 

inputs = tf.random.normal((1, 4096, 4))

print("NumPy")
print(np.quantile(inputs.numpy(), q=0.9, axis=1, keepdims=False))

The TFP docs don't make it clear to me how to write the above using tfp.stats.quantiles(). I tried reading the source code of both functions, but it didn't help.


Solution

  • Let me try to be more helpful here than I was on GitHub.

    There is a difference in behavior between np.quantile and tfp.stats.quantiles. The key difference here is that numpy.quantile will

    Compute the q-th quantile of the data along the specified axis.

    where q is the

    Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.

    and tfp.stats.quantiles

    Given a vector x of samples, this function estimates the cut points by returning num_quantiles + 1 cut points

    So you need to tell tfp.stats.quantiles how many quantiles (intervals) you want and then select the cut point that corresponds to q. Because the cut points are stacked along a new leading axis, that selection is an index into axis 0. If this isn't clear from the API alone, the source of tfp.stats.quantiles (v0.19.0) shows how to get a return structure similar to NumPy's.

    For completeness, setting up a virtual environment with

    $ cat requirements.txt
    numpy==1.24.2
    tensorflow==2.11.0
    tensorflow-probability==0.19.0
    

    allows us to run

    import numpy as np
    import tensorflow as tf
    import tensorflow_probability as tfp
    
    inputs = tf.random.normal((1, 4096, 4), dtype=tf.float64)
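    q = 0.9
    num_quantiles = 100  # cut points at 0%, 1%, ..., 100%
    
    numpy_quantiles = np.quantile(inputs.numpy(), q=q, axis=1, keepdims=False)
    
    # "linear" matches np.quantile's default; tfp's default is "nearest".
    # Cut point k (on the new leading axis) is the (k / num_quantiles)-th quantile,
    # so this only works when q is a multiple of 1 / num_quantiles. Use round(),
    # not int(): e.g. 0.29 * 100 == 28.999999999999996, which int() truncates to 28.
    tfp_quantiles = tfp.stats.quantiles(
        inputs, num_quantiles=num_quantiles, axis=1, interpolation="linear"
    )[round(q * num_quantiles)]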
    
    assert np.allclose(numpy_quantiles, tfp_quantiles.numpy())
    
    print(f"{numpy_quantiles=}")
    # numpy_quantiles=array([[1.31727661, 1.2699167 , 1.28735237, 1.27137588]])
    print(f"{tfp_quantiles=}")
    # tfp_quantiles=<tf.Tensor: shape=(1, 4), dtype=float64, numpy=array([[1.31727661, 1.2699167 , 1.28735237, 1.27137588]])>