I am trying to learn how to sample truncated distributions. To begin with, I decided to try a simple example I found here.
I didn't really understand the division by the CDF, so I decided to tweak the algorithm a bit. The distribution being sampled is an exponential-type density, pdf(x) = x*exp(-x), for values x > 0.
Here is an example in Python:
# Sample exponential distribution for the case x>0
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
def pdf(x):
    return x*np.exp(-x)
xvec=np.zeros(1000000)
x=1.
for i in range(1000000):
    a=x+np.random.normal()
    xs=x
    if a > 0. :
        xs=a
    A=pdf(xs)/pdf(x)
    if np.random.uniform()<A :
        x=xs
    xvec[i]=x
x=np.linspace(0,15,1000)
plt.plot(x,pdf(x))
# density=True replaces normed=True, which newer matplotlib versions no longer accept
plt.hist([x for x in xvec if x != 0],bins=150,density=True)
plt.show()
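Comparing a histogram with a curve by eye makes small biases hard to see. A minimal sketch (an addition, not the original code) wraps the loop above in a function and compares the chain mean with the exact mean of the truncated density, (a^2 + 2a + 2)/(1 + a) for a lower bound a, which follows from integrating x^2*exp(-x) from a to infinity and dividing by 1 - CDF(a) = exp(-a)*(1 + a). The names mh_truncated and truncated_mean and the fixed seed are illustrative choices:

```python
import math
import numpy as np

def target(x):
    # Unnormalized density x*exp(-x); only called for x > lower >= 0.
    return x * math.exp(-x)

def mh_truncated(n, lower=0.0, shift=1.0, seed=0):
    """Random-walk Metropolis-Hastings for x*exp(-x) restricted to x > lower (lower >= 0).

    A proposal outside the support has target density 0, so it is rejected
    and the chain repeats its current value, as in the loop above.
    """
    rng = np.random.default_rng(seed)
    steps = rng.normal(0.0, shift, size=n)  # symmetric Gaussian proposal increments
    u = rng.uniform(size=n)                 # uniforms for the accept/reject test
    x = lower + 1.0                         # start strictly inside the support (x=1 for lower=0)
    out = np.empty(n)
    for i in range(n):
        xs = x + steps[i]
        if xs > lower and u[i] < target(xs) / target(x):
            x = xs
        out[i] = x
    return out

def truncated_mean(lower):
    # E[X | X > lower] for the density x*exp(-x)
    return (lower**2 + 2 * lower + 2) / (1 + lower)

results = {}
for lower in (0.0, 0.5):
    chain = mh_truncated(200_000, lower=lower)
    results[lower] = (chain.mean(), truncated_mean(lower))
    print(f"lower={lower}: chain mean {chain.mean():.3f}, exact {truncated_mean(lower):.3f}")
```

If both means agree, the sampler is targeting the truncated density, and a visual mismatch points at the reference curve rather than at the chain.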
The code above seems to work fine only when using the condition if a > 0. : (i.e. positive x); choosing another condition (e.g. if a > 0.5 :) produces wrong results.
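Part of the apparent error can be worked out on paper (a sketch, assuming the histogram is drawn with density normalization). The Metropolis step already targets the density restricted to x > a, but that restricted density has total mass 1 - F(a), so the reference curve has to be rescaled:

```latex
% Density being sampled and its CDF (integration by parts)
f(x) = x\,e^{-x}, \qquad F(x) = \int_0^x t\,e^{-t}\,dt = 1 - e^{-x} - x\,e^{-x}, \qquad x > 0

% Truncated density on x > a
f_a(x) = \frac{f(x)}{1 - F(a)} \quad (x > a), \qquad f_a(x) = 0 \quad (x \le a)

% a = 0: 1 - F(0) = 1, so the unscaled curve already matches
% a = 0.5:
1 - F(0.5) = e^{-0.5} + 0.5\,e^{-0.5} = 1.5\,e^{-0.5} \approx 0.9098
```

So for the a > 0.5 condition the histogram should sit about 1/0.9098, roughly 1.10 times, above pdf(x) for x > 0.5 and be empty below 0.5.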
Since my final goal is to sample a 2D Gaussian pdf on a truncated domain, I tried extending the simple exponential example (see the code below). Unfortunately, since the simple case didn't work, I assume that the code below would also yield wrong results.
I assume that all of this can be done using Python's more advanced tools. However, since my primary goal is to understand the underlying principle, I would greatly appreciate your help in understanding my mistake. Thank you for your help.
EDIT:
# code updated according to the answer of CrazyIvan
import numpy as np
from scipy.stats import multivariate_normal
RANGE=100000
a=2.06072E-02
b=1.10011E+00
a_range=[0.001,0.5]
b_range=[0.01, 2.5]
cov=[[3.1313994E-05, 1.8013737E-03],[ 1.8013737E-03, 1.0421529E-01]]
x=a
y=b
j=0
for i in range(RANGE):
    a_t,b_t=np.random.multivariate_normal([a,b],cov)
    # accept if within bounds - all that is needed to truncate
    if a_range[0]<a_t and a_t<a_range[1] and b_range[0]<b_t and b_t<b_range[1]:
        print(a_t,b_t)
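In 2D, the role of 1 - CDF(0.5) is played by the probability of the box a_range by b_range under the untruncated Gaussian: it normalizes the truncated pdf and is also the expected fraction of accepted draws. A sketch, assuming scipy's multivariate_normal.cdf and inclusion-exclusion over the four box corners, with vectorized draws in place of the loop (seed and sample count are arbitrary); for these parameters the fraction should come out close to 1:

```python
import numpy as np
from scipy.stats import multivariate_normal

mean = [2.06072E-02, 1.10011E+00]
cov = [[3.1313994E-05, 1.8013737E-03], [1.8013737E-03, 1.0421529E-01]]
a_range = [0.001, 0.5]
b_range = [0.01, 2.5]

# Vectorized draws instead of one draw per loop iteration.
rng = np.random.default_rng(0)
n = 200_000
s = rng.multivariate_normal(mean, cov, size=n)
inside = ((a_range[0] < s[:, 0]) & (s[:, 0] < a_range[1])
          & (b_range[0] < s[:, 1]) & (s[:, 1] < b_range[1]))
empirical = inside.mean()

def joint_cdf(a, b):
    # P(A <= a, B <= b) for the untruncated Gaussian
    return multivariate_normal.cdf([a, b], mean=mean, cov=cov)

# Box probability by inclusion-exclusion over the four corners.
box_prob = (joint_cdf(a_range[1], b_range[1]) - joint_cdf(a_range[0], b_range[1])
            - joint_cdf(a_range[1], b_range[0]) + joint_cdf(a_range[0], b_range[0]))

print(f"accepted fraction {empirical:.5f}, box probability {box_prob:.5f}")
```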
EDIT:
I changed the code by normalizing the analytic pdf according to this scheme, and according to the answers given by @Crazy Ivan and @Leandro Caniglia, for the case where the bottom of the pdf is removed. That is, dividing by (1-CDF(0.5)), since my accept condition is x > 0.5. This again shows some discrepancies. The mystery prevails...
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
def pdf(x):
    return x*np.exp(-x)
# included the corresponding cdf
def cdf(x):
    return 1. -np.exp(-x)-x*np.exp(-x)
xvec=np.zeros(1000000)
x=1.
for i in range(1000000):
    a=x+np.random.normal()
    xs=x
    if a > 0.5 :
        xs=a
    A=pdf(xs)/pdf(x)
    if np.random.uniform()<A :
        x=xs
    xvec[i]=x
x=np.linspace(0,15,1000)
# new part: normalize the analytic pdf so that its area over x > 0.5 is 1
plt.plot(x,pdf(x)/(1.-cdf(0.5)))
plt.hist([x for x in xvec if x != 0],bins=200,density=True)
plt.savefig("test_exp.png")
plt.show()
It seems that this can be cured by choosing a larger shift size:
shift=15.
a=x+np.random.normal()*shift
Step-size tuning is a general issue with Metropolis–Hastings. See the graph below:
The bottom line is that changing the shift size definitely improves convergence. The mystery is why, since the Gaussian proposal is unbounded.
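One way to see why the step size matters even though the Gaussian proposal is unbounded: an unbounded proposal only guarantees that every region can eventually be reached, not how quickly. Small steps are accepted often but move little; very large steps mostly land where the density is tiny or outside the support and get rejected. Both show up as high autocorrelation, i.e. fewer effectively independent samples per iteration. A sketch (the shifts 1, 3 and 15 and the seed are arbitrary choices) that reports acceptance rate, lag-1 autocorrelation and chain mean for the x > 0.5 case, whose exact mean is 3.25/1.5:

```python
import math
import numpy as np

def mh_chain(n, lower=0.5, shift=1.0, seed=0):
    """Random-walk Metropolis-Hastings for x*exp(-x) on x > lower; returns (chain, acceptance rate)."""
    rng = np.random.default_rng(seed)
    steps = rng.normal(0.0, shift, size=n)
    u = rng.uniform(size=n)
    x, accepted = lower + 0.5, 0  # start at x = 1 for lower = 0.5
    out = np.empty(n)
    for i in range(n):
        xs = x + steps[i]
        # Proposals outside the support have zero density and are always rejected.
        if xs > lower and u[i] < (xs * math.exp(-xs)) / (x * math.exp(-x)):
            x = xs
            accepted += 1
        out[i] = x
    return out, accepted / n

summary = {}
for shift in (1.0, 3.0, 15.0):
    chain, acc = mh_chain(100_000, shift=shift)
    lag1 = np.corrcoef(chain[:-1], chain[1:])[0, 1]  # close to 1 means slow mixing
    summary[shift] = (acc, lag1, chain.mean())
    print(f"shift={shift:5.1f}  acceptance={acc:.2f}  lag-1 autocorr={lag1:.2f}  mean={chain.mean():.3f}")
```

A common rule of thumb for a 1D random-walk proposal is a step of roughly 2.4 times the target's standard deviation, which is about 1.4 here.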
You say you want to learn the basic idea of sampling a truncated distribution, but your source is a blog post about the Metropolis–Hastings algorithm? Do you actually need this "method for obtaining a sequence of random samples from a probability distribution for which direct sampling is difficult"? Taking this as your starting point is like learning English by reading Shakespeare.
For a truncated normal, basic rejection sampling is all you need: generate samples from the original distribution and reject those outside the bounds. As Leandro Caniglia noted, you should not expect the truncated distribution to have the same PDF, just on a shorter interval; this is plainly impossible because the area under the graph of a PDF is always 1. If you cut stuff off the sides, there has to be more in the middle: the PDF gets rescaled.
It's quite inefficient to gather samples one by one when you need 100000. I would grab 100000 normal samples at once, accept only those that fit, and repeat until I have enough. Here is an example of sampling a truncated normal between amin and amax:
import numpy as np
n_samples = 100000
amin, amax = -1, 2
samples = np.zeros((0,))  # empty for now
while samples.shape[0] < n_samples:
    s = np.random.normal(0, 1, size=(n_samples,))
    accepted = s[(s >= amin) & (s <= amax)]
    samples = np.concatenate((samples, accepted), axis=0)
samples = samples[:n_samples]  # we probably got more than needed, so discard extra ones
And here is the comparison with the PDF curve, rescaled by division by cdf(amax) - cdf(amin) as explained above.
import matplotlib.pyplot as plt
from scipy.stats import norm
_ = plt.hist(samples, bins=50, density=True)
t = np.linspace(-2, 3, 500)
plt.plot(t, norm.pdf(t)/(norm.cdf(amax) - norm.cdf(amin)), 'r')
plt.show()
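The constant cdf(amax) - cdf(amin) used for the rescaling is also the probability that a single normal draw is accepted, so it predicts how wasteful the rejection loop is. A sketch that checks both the acceptance rate and the sample mean, using scipy.stats.truncnorm only as an independent reference (its a and b are given in standard units, (bound - loc)/scale, which for the standard normal are just amin and amax); the seed is arbitrary:

```python
import numpy as np
from scipy.stats import norm, truncnorm

amin, amax = -1, 2
n_draws = 100_000

rng = np.random.default_rng(0)
s = rng.normal(0, 1, size=n_draws)
accepted = s[(s >= amin) & (s <= amax)]

# Mass of the standard normal on [amin, amax]: both the normalizing constant
# of the truncated pdf and the acceptance probability of one draw.
Z = norm.cdf(amax) - norm.cdf(amin)
rate = accepted.size / n_draws

# Exact mean of the normal truncated to [amin, amax] (loc=0, scale=1).
exact_mean = truncnorm(amin, amax).mean()

print(f"acceptance {rate:.4f} vs Z {Z:.4f}; sample mean {accepted.mean():.4f} vs {exact_mean:.4f}")
```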
Now we want to keep the first coordinate between amin and amax, and the second between bmin and bmax. Same story, except there will be a 2-column array and the comparison with bounds is done in a relatively sneaky way:
(np.min(s - [amin, bmin], axis=1) >= 0) & (np.max(s - [amax, bmax], axis=1) <= 0)
This means: subtract amin, bmin from each row and keep only the rows where both results are nonnegative (meaning we had a >= amin and b >= bmin). Also do a similar thing with amax, bmax. Accept only the rows that meet both criteria.
n_samples = 10
amin, amax = -1, 2
bmin, bmax = 0.2, 2.4
mean = [0.3, 0.5]
cov = [[2, 1.1], [1.1, 2]]
samples = np.zeros((0, 2))  # 2 columns now
while samples.shape[0] < n_samples:
    s = np.random.multivariate_normal(mean, cov, size=(n_samples,))
    accepted = s[(np.min(s - [amin, bmin], axis=1) >= 0) & (np.max(s - [amax, bmax], axis=1) <= 0)]
    samples = np.concatenate((samples, accepted), axis=0)
samples = samples[:n_samples, :]
Not going to plot, but here are some values: naturally, within bounds.
array([[ 0.43150033, 1.55775629],
[ 0.62339265, 1.63506963],
[-0.6723598 , 1.58053835],
[-0.53347361, 0.53513105],
[ 1.70524439, 2.08226558],
[ 0.37474842, 0.2512812 ],
[-0.40986396, 0.58783193],
[ 0.65967087, 0.59755193],
[ 0.33383214, 2.37651975],
[ 1.7513789 , 1.24469918]])
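To see that the min/max trick is just an "every coordinate within its bounds" test, a small sketch (seed and sample size are arbitrary) compares it with a more literal np.all form, in which the bound vectors broadcast against every row:

```python
import numpy as np

amin, amax = -1, 2
bmin, bmax = 0.2, 2.4
lo = np.array([amin, bmin])
hi = np.array([amax, bmax])

rng = np.random.default_rng(1)
s = rng.multivariate_normal([0.3, 0.5], [[2, 1.1], [1.1, 2]], size=1000)

# The min/max trick: the smallest entry of each row of s - lo is >= 0 and
# the largest entry of each row of s - hi is <= 0.
mask_trick = (np.min(s - [amin, bmin], axis=1) >= 0) & (np.max(s - [amax, bmax], axis=1) <= 0)

# Literal form: compare each column with its bounds, then require both columns to pass.
mask_all = np.all((s >= lo) & (s <= hi), axis=1)

print(np.array_equal(mask_trick, mask_all), mask_all.mean())
```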