python pytorch statistics numerical-methods

Numerically stable noncentral chi-squared distribution in torch?


I need a numerically stable implementation of the noncentral chi-squared distribution (its density function) in torch.


Solution

  • Following the description at https://en.wikipedia.org/wiki/Noncentral_chi-squared_distribution:

    import torch
    import matplotlib.pyplot as plt
    # select matplotlib's 'dark_background' style sheet before any figure is drawn
    plt.style.use('dark_background')
    
    # log of the individual series terms of the modified Bessel function of the
    # first kind, I_v(y); the series is infinite, so it is truncated after
    # `infinity` terms (100 by default)
    def bezel(v,y,infinity=100):
        """Return the log of each term of the power series of ``I_v(y)``.

        The series is::

            I_v(y) = sum_j (y/2)**(2*j + v) / (j! * Gamma(j + v + 1))

        and entry ``j`` along the last axis of the result is the logarithm of
        the ``j``-th summand. Reduce with ``.exp().sum(-1)`` (or
        ``torch.logsumexp(..., -1)``) to obtain ``I_v(y)`` (or its log).

        Args:
            v: order of the Bessel function (number or tensor).
            y: non-negative argument (number or tensor of any shape).
            infinity: number of series terms to keep.

        Returns:
            Tensor of shape ``y.shape + (infinity,)``.
        """
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y)
        if not isinstance(v, torch.Tensor):
            v = torch.tensor(v)
        # build the summation index on the same device as the argument
        j = torch.arange(0, infinity, device=y.device)
        half_y = (0.5 * y).unsqueeze(-1)
        log_denominator = torch.lgamma(j + v + 1) + torch.lgamma(j + 1)
        # xlogy(a, b) is a*log(b) but returns 0 when a == 0, so y == 0 yields
        # the correct limits (I_0(0) = 1, I_v(0) = 0 for v > 0) instead of the
        # NaN produced by 0 * log(0)
        log_numerator = torch.xlogy(2 * j, half_y)
        return torch.xlogy(v, half_y) + log_numerator - log_denominator
    
    def noncentral_chi2(x,mu,k,infinity=100):
        """Evaluate the noncentral chi-squared probability density at ``x``.

        The density (see
        https://en.wikipedia.org/wiki/Noncentral_chi-squared_distribution) is::

            f(x) = 1/2 * exp(-(x + mu)/2) * (x/mu)**(k/4 - 1/2) * I_v(sqrt(mu*x))

        with ``v = k/2 - 1``. Substituting the power series of ``I_v`` and
        multiplying the powers out gives the equivalent form used here::

            f(x) = 1/2 * exp(-(x + mu)/2) * (x/2)**v
                   * sum_j (mu*x/4)**j / (j! * Gamma(j + v + 1))

        in which ``mu`` appears neither in a denominator nor under a square
        root, so ``mu == 0`` reduces to the central chi-squared density
        instead of NaN. All products are accumulated in log space.

        Args:
            x: point(s) at which to evaluate the density (number or tensor).
            mu: noncentrality parameter, often written lambda (number or
                tensor broadcastable with ``x``).
            k: degrees of freedom (number or tensor broadcastable with ``x``).
            infinity: number of series terms to keep.

        Returns:
            Tensor of densities with the broadcast shape of ``x``, ``mu``, ``k``.
            The inputs are never modified.
        """
        x, mu, k = (
            t if isinstance(t, torch.Tensor) else torch.tensor(t)
            for t in (x, mu, k)
        )
        # add a trailing axis for the series index; use the out-of-place
        # unsqueeze so the caller's tensors keep their shape, and leave 0-dim
        # tensors alone since they already broadcast against everything
        x, mu, k = (t.unsqueeze(-1) if t.dim() else t for t in (x, mu, k))

        half_x = 0.5 * x
        v = 0.5 * k - 1
        # series index in the working float dtype, on the same device as x
        j = torch.arange(infinity, dtype=half_x.dtype, device=half_x.device)

        # xlogy(a, b) = a*log(b) with the convention 0*log(0) = 0, which keeps
        # the j = 0 term finite when mu*x == 0
        log_terms = (
            torch.xlogy(v, half_x)
            - half_x
            - 0.5 * mu
            + torch.xlogy(j, 0.5 * mu * half_x)
            - torch.lgamma(j + v + 1)
            - torch.lgamma(j + 1)
        )
        # sum the series in log space; the leading factor 1/2 is exact
        return 0.5 * torch.logsumexp(log_terms, dim=-1).exp()
    
    # means of the independent unit-variance normal variables whose squares are summed
    loc = torch.rand(5)
    normal = torch.distributions.Normal(loc, 1)

    # noncentrality parameter (also written lambda): the sum of the squared means
    mu = loc.pow(2).sum()

    # number of simulated sums
    events = 5000
    Xs = normal.sample((events,))

    # squaring and summing each row gives one noncentral chi-squared sample
    Y = Xs.pow(2).sum(-1)

    # evaluation grid running from 0.1 to a little past the largest sample
    t = torch.linspace(0.1, Y.max() + 10, 100)
    dist = noncentral_chi2(t, mu, len(loc))

    # draw the computed density on top of the normalised histogram of the samples
    plt.title(f"k={len(loc)}, mu={mu:0.2f}")
    plt.hist(Y, bins=int(events**0.5), density=True)
    plt.plot(t, dist)
    

    k=100 k=10