python · scipy · statistics · transport · numpy-random

1D Wasserstein distance in Python


The formula below is a special case of the Wasserstein distance/optimal transport problem for when the source and target distributions, x and y (also called the marginal distributions), are 1D, that is, vectors.

$$W_p(u, v) = \left( \int_0^1 \left| F_u^{-1}(t) - F_v^{-1}(t) \right|^p \, dt \right)^{1/p}$$

where the F^{-1} are the inverse cumulative distribution functions (quantile functions) of the marginals u and v, derived from real data vectors x and y, both generated here from the standard normal distribution:

import numpy as np
from numpy.random import randn
import scipy.stats as ss

n = 100
x = randn(n)
y = randn(n)

How can the integral in the formula be coded in Python and SciPy? I'm guessing x and y have to be converted to ranked marginals, which are non-negative and sum to 1, and that SciPy's ppf could be used to calculate the inverse functions F^{-1}?


Solution

  • Note that as n gets large, a sorted set of n samples approaches the inverse CDF sampled at t = 1/n, 2/n, ..., n/n. E.g.:

    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.stats import norm
    # avoid t = 0 and t = 1, where norm.ppf returns -inf / +inf
    t = (np.arange(1000) + 0.5) / 1000
    plt.plot(norm.ppf(t), label="invcdf")
    plt.plot(np.sort(np.random.normal(size=1000)), label="sortsample")
    plt.legend()
    plt.show()
    

    (plot: the sorted sample closely tracks the inverse CDF)

    Also note that your integral from 0 to 1 can be approximated as a sum over 1/n, 2/n, ..., n/n.
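As a quick numerical check of that approximation (an added illustration, not part of the original answer): evaluating the standard normal inverse CDF at the grid points and averaging |Φ⁻¹(t)| should converge to E|Z| = √(2/π) for Z ~ N(0, 1):

```python
import numpy as np
from scipy.stats import norm

n = 10_000
# evaluate the inverse CDF at bin midpoints to avoid ppf(0) = -inf, ppf(1) = +inf
t = (np.arange(n) + 0.5) / n
approx = np.mean(np.abs(norm.ppf(t)))  # Riemann sum approximating E|Z|

print(abs(approx - np.sqrt(2 / np.pi)) < 1e-2)  # E|Z| = sqrt(2/pi) ≈ 0.7979
```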

    Thus we can simply answer your question:

    def W(p, u, v):
        assert len(u) == len(v)
        return np.mean(np.abs(np.sort(u) - np.sort(v))**p)**(1/p)
    
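As a sanity check (an added sketch, not part of the original answer; seed and sample sizes are arbitrary): for p = 1 and equal-length samples this estimator should agree with SciPy's built-in scipy.stats.wasserstein_distance, which computes the same quantity from the empirical CDFs:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def W(p, u, v):
    assert len(u) == len(v)
    return np.mean(np.abs(np.sort(u) - np.sort(v))**p)**(1/p)

rng = np.random.default_rng(0)
u = rng.normal(size=500)
v = rng.normal(loc=1.0, size=500)

# for p = 1 and equal-length samples this matches SciPy's 1D Wasserstein distance
print(np.isclose(W(1, u, v), wasserstein_distance(u, v)))
```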

    Note that if len(u) != len(v) you can still apply the method with linear interpolation:

    def W(p, u, v):
        u = np.sort(u)
        v = np.sort(v)
        if len(u) != len(v):
            if len(u) > len(v): u, v = v, u
            # resample the shorter (now u) quantile function onto the longer grid
            us = np.linspace(0, 1, len(u))
            vs = np.linspace(0, 1, len(v))
            u = np.interp(vs, us, u)
        return np.mean(np.abs(u - v)**p)**(1/p)
    
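A usage sketch of this interpolated variant (an added illustration; the seed and sample sizes are arbitrary): for p = 1, SciPy's wasserstein_distance also accepts samples of different lengths, so the interpolated estimate should land close to it:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def W(p, u, v):
    # sort both samples and, if lengths differ, resample the shorter
    # quantile function onto the longer grid with np.interp
    u = np.sort(u)
    v = np.sort(v)
    if len(u) != len(v):
        if len(u) > len(v):
            u, v = v, u
        us = np.linspace(0, 1, len(u))
        vs = np.linspace(0, 1, len(v))
        u = np.interp(vs, us, u)
    return np.mean(np.abs(u - v)**p)**(1/p)

rng = np.random.default_rng(1)
u = rng.normal(size=800)
v = rng.normal(loc=1.0, size=1200)

d = W(1, u, v)
# the interpolation introduces some discretization error, so compare with a tolerance
print(abs(d - wasserstein_distance(u, v)) < 0.1)
```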

    An alternative method, if you have prior information about the type of distribution of your data but not its parameters, is to fit the best-matching distribution to both u and v (e.g. with scipy.stats.norm.fit) and then evaluate the integral with the desired precision. E.g.:

    from scipy.stats import norm as gauss
    def W_gauss(p, u, v, num_steps):
        ud = gauss(*gauss.fit(u))
        vd = gauss(*gauss.fit(v))
        # midpoints of num_steps equal bins on (0, 1): avoids ppf(0)/ppf(1)
        z = np.linspace(0, 1, num_steps, endpoint=False) + 1/(2*num_steps)
        return np.mean(np.abs(ud.ppf(z) - vd.ppf(z))**p)**(1/p)
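As a check on this parametric route (an added illustration; the seed and parameters are arbitrary): for two Gaussians the 2-Wasserstein distance has the closed form √((m₁−m₂)² + (s₁−s₂)²), so with p = 2 the quadrature should reproduce that value at the fitted parameters:

```python
import numpy as np
from scipy.stats import norm as gauss

def W_gauss(p, u, v, num_steps):
    ud = gauss(*gauss.fit(u))
    vd = gauss(*gauss.fit(v))
    # midpoints of num_steps equal bins on (0, 1)
    z = np.linspace(0, 1, num_steps, endpoint=False) + 1/(2*num_steps)
    return np.mean(np.abs(ud.ppf(z) - vd.ppf(z))**p)**(1/p)

rng = np.random.default_rng(2)
u = rng.normal(0.0, 1.0, size=5000)
v = rng.normal(1.0, 2.0, size=5000)

# closed form for two Gaussians, evaluated at the *fitted* parameters,
# so only the quadrature error remains
m1, s1 = gauss.fit(u)
m2, s2 = gauss.fit(v)
closed = np.sqrt((m1 - m2)**2 + (s1 - s2)**2)

print(abs(W_gauss(2, u, v, 100_000) - closed) < 1e-2)
```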