statistics julia normal-distribution matrix-factorization

In the Distributions.jl package for Julia, how do I define an MvNormal distribution from its Cholesky factor?


I am writing some code that transforms multivariate distributions by operating mostly on the Cholesky factor of the covariance matrix. Suppose I have

using Distributions, LinearAlgebra
g1 = MvNormal([1,2], [2 1; 1 2])

The naive way to do it is

c1 = cholesky(cov(g1)).L
# Work on matrix c1 in some way; for this example I'll just assign it directly to c2
c2 = c1
s2 = c2*c2'
g2 = MvNormal([1,2], s2)

But this means repeatedly calculating the Cholesky factor of a covariance matrix and then rebuilding the corresponding covariance matrix. For reasons of speed and numerical stability I would prefer to always work with Cholesky factors, as I rarely need the covariance matrix itself. Note that even with this very simple example, some rounding error is already visible:

g2 = MvNormal([1,2], s2)
FullNormal(
dim: 2
μ: [1.0, 2.0]
Σ: [2.0000000000000004 1.0; 1.0 1.9999999999999996]
)

What is the correct Julia way to do this? I've seen, for example, this answer to a GitHub issue which suggests it's possible to specify only the Cholesky factor for MvNormal, but it assumes some familiarity with Julia which I don't have, and so it's not enough for me to figure out how to actually do it. Also it's quite old, and the surrounding discussion suggests the way this works has been changed in the meantime.

So, to summarize: if I have the Cholesky factor of a covariance matrix, how can I efficiently define an MvNormal such that I can get that factor back efficiently, and still have access to the methods defined for MvNormal?

Note: the transformation of c1 to c2 in my application is an implementation of the square root form of the unscented transform. I realise there are probably libraries that do this, but I would prefer to have my own implementation since I plan to tinker with it.


Solution

  • Using the PDMats package, I think you can get the functionality you need:

    julia> using Distributions, LinearAlgebra, PDMats
    
    julia> R = rand(5,5);
    
    julia> M = R*transpose(R);
    
    julia> C = cholesky(M)
    Cholesky{Float64, Matrix{Float64}}
    U factor:
    5×5 UpperTriangular{Float64, Matrix{Float64}}:
     1.44058  1.09028   0.971589  1.09389   1.36865
      ⋅       0.939203  0.783269  0.721269  0.0734465
      ⋅        ⋅        0.576735  0.286545  0.614992
      ⋅        ⋅         ⋅        0.892629  0.662157
      ⋅        ⋅         ⋅         ⋅        0.256624
    
    julia> N = PDMat(C)
    5×5 PDMat{Float64, Matrix{Float64}}:
     2.07526  1.57064  1.39965  1.57583  1.97165
     1.57064  2.07082  1.79496  1.87007  1.5612
     1.39965  1.79496  1.89012  1.79302  1.74198
     1.57583  1.87007  1.79302  2.59572  2.31741
     1.97165  1.5612   1.74198  2.31741  2.76113
    
    julia> N ≈ M
    true
    
    julia> mvn = MvNormal(zeros(5), N);
    
    julia> rand(mvn, 2)
    5×2 Matrix{Float64}:
     -3.64747   -1.30948
     -1.84731   -1.00835
     -1.68674   -1.01449
     -0.818989   2.4971
     -2.87103    1.37066
    

    So now, mvn should be the MvNormal with C as the Cholesky factor of its covariance. Any new Cholesky factor can be turned into an MvNormal in the same way: construct a Cholesky factor object from a lower-triangular matrix, then construct a PDMat from that object.

    Hopefully, the overhead is less than generating the new MvNormal covariance matrix each time.

    For example:

    julia> L = LowerTriangular(abs.(rand(3,3)))
    3×3 LowerTriangular{Float64, Matrix{Float64}}:
     0.384057   ⋅         ⋅ 
     0.447769  0.719279   ⋅ 
     0.350396  0.474858  0.608242
    
    julia> CN = Cholesky(L)
    Cholesky{Float64, Matrix{Float64}}
    L factor:
    3×3 LowerTriangular{Float64, Matrix{Float64}}:
     0.384057   ⋅         ⋅ 
     0.447769  0.719279   ⋅ 
     0.350396  0.474858  0.608242
    
    julia> L_mvn = MvNormal(zeros(3), PDMat(CN))
    FullNormal(
    dim: 3
    μ: [0.0, 0.0, 0.0]
    Σ: [0.14750000546843048 0.17196880546415463 0.1345721043936468; 0.17196880546415463 0.7178585036224733 0.498451513851043; 0.1345721043936468 0.498451513851043 0.71822573422471]
    )
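
The question also asks how to get the factor back out. The following is a minimal sketch of the round trip, not a definitive implementation. The helper names `mvnormal_from_factor` and `stored_factor` are hypothetical. The sketch assumes that `MvNormal` keeps its covariance in the field `Σ` and that `PDMat` keeps its factorization in the field `chol`; these are internal fields and may differ between package versions.

```julia
using Distributions, LinearAlgebra, PDMats

# Hypothetical helper: build an MvNormal from a lower-triangular Cholesky factor.
mvnormal_from_factor(μ::AbstractVector, L::LowerTriangular) =
    MvNormal(μ, PDMat(Cholesky(L)))

# Hypothetical helper: read the stored factor back without refactorizing.
# Assumption: the covariance lives in d.Σ (a PDMat) and its factorization in d.Σ.chol.
stored_factor(d::MvNormal) = d.Σ.chol.L

L = LowerTriangular([1.5 0.0; 0.5 1.0])
d = mvnormal_from_factor([1.0, 2.0], L)

# The factor that was passed in is the one that is stored.
L_back = stored_factor(d)

# Transform the factor (here simply scaled by 2) and build the next distribution.
d2 = mvnormal_from_factor(mean(d), 2 * stored_factor(d))
```

Here `d` and `d2` are ordinary `MvNormal` objects, so `rand`, `logpdf`, `cov` and so on remain available, while the square-root update itself only touches the factor.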