
SciPy's multivariate earth mover distance not working as intended?

I am using SciPy's multivariate earth mover distance function wasserstein_distance_nd. I did a quick sanity check that confused me: since I draw two multivariate Gaussian samples from the same distribution, I expected an earth mover distance close to 0. However, I am getting something large (e.g., 12). Why is this happening? I also tested the one-dimensional case and saw something similar (there, the distance is always positive).

Here is the code I used:

import numpy as np
from scipy.stats import wasserstein_distance_nd

mean = np.zeros(100)
cov = np.eye(100)
size = 100
sample1 = np.random.multivariate_normal(mean, cov, size)
sample2 = np.random.multivariate_normal(mean, cov, size)

print("EMD", wasserstein_distance_nd(sample1, sample2))
# output: EMD 12.293968193381374

# single dimension
import numpy as np
from scipy.stats import wasserstein_distance

mean = 0
var = 1
size = 100
sample1 = np.random.normal(mean, np.sqrt(var), size)
sample2 = np.random.normal(mean, np.sqrt(var), size)
dist = wasserstein_distance(sample1, sample2)

print("wasserstein_distance", dist)
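The strictly positive 1D result is expected: two finite samples never have exactly the same empirical distribution, so their distance is above 0, but it should shrink toward 0 as the sample size grows. A minimal sketch to check this, assuming an arbitrary seed and an arbitrary list of sample sizes (neither comes from the original code):

```python
import numpy as np
from scipy.stats import wasserstein_distance

# Arbitrary seed (assumption) so the run is reproducible.
rng = np.random.default_rng(0)

# Two independent samples from the same N(0, 1) distribution for growing sizes.
# The empirical distributions never coincide, so the distance is > 0,
# but it shrinks roughly like 1/sqrt(size).
results = {}
for size in [100, 1_000, 10_000, 100_000]:
    sample1 = rng.normal(0.0, 1.0, size)
    sample2 = rng.normal(0.0, 1.0, size)
    results[size] = wasserstein_distance(sample1, sample2)
    print(size, results[size])
```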


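  • Your multivariate_normal example lives in a 100-dimensional space. Intuitively, there is so much room in 100 dimensions that two samples of only 100 points each cannot look very similar: every point in sample1 is far from every point in sample2. Concretely, two independent draws from N(0, I) in d dimensions have an expected squared Euclidean distance of 2d, so in 100 dimensions a typical pair of points is about sqrt(200) ≈ 14 apart, and even the best matching between the two samples cannot bring that cost down to anywhere near 0. That is why you see a value around 12.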

    For a rigorous explanation, ask on Mathematics Stack Exchange. My point is just that intuition about the magnitude of distances built in the one-dimensional case does not carry over to 100 dimensions.

    If it helps, you can confirm that wasserstein_distance_nd agrees with wasserstein_distance in 1D once each sample is reshaped into a single column of shape (size, 1).

    import numpy as np
    from scipy.stats import wasserstein_distance, wasserstein_distance_nd

    mean = 0
    var = 1
    size = 100
    sample1 = np.random.normal(mean, np.sqrt(var), size)
    sample2 = np.random.normal(mean, np.sqrt(var), size)

    # 1D reference result
    ref = wasserstein_distance(sample1, sample2)
    # wasserstein_distance_nd expects 2D arrays of shape (n_points, n_dims),
    # so turn each 1D sample into a single column
    res = wasserstein_distance_nd(sample1[:, np.newaxis], sample2[:, np.newaxis])
    print("wasserstein_distance", res, ref)
    # wasserstein_distance 0.2109383092226257 0.21093830922262574