
Efficient algorithm for online variance over image batches


I have a large number of images and want to calculate the variance of each channel across all of them. I am struggling to find an efficient (and even correct) algorithm for this.

I found Welford's online algorithm, but it is far too slow: it updates the statistics one sample at a time, so it does not vectorize across a single image or a batch of images.
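
For context, a minimal sketch of the per-sample Welford update is below, assuming scalar values (the function name `welford` is hypothetical). Every pixel value goes through one iteration of an interpreted Python loop, which is where the slowness comes from when there are millions of pixels.

```python
def welford(samples):
    """Per-sample Welford update over an iterable of scalars.

    Returns (mean, population variance). Hypothetical helper for illustration.
    """
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for x in samples:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)  # uses the updated mean
    # Divide by count for the population variance (count - 1 would give the sample variance).
    variance = m2 / count if count else float("nan")
    return mean, variance
```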

How can I speed it up, either by vectorizing it or by making use of built-in variance functions?


Solution

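  • These two functions update (combine) the mean and variance of the samples seen so far with those of a new batch. Both work element-wise on vectors, e.g. one entry per color channel. The per-batch mean and variance can be obtained from built-in methods such as batch.mean() and batch.var() (for a NumPy array of shape (N, H, W, C), batch.var(axis=(0, 1, 2))). Note that the formulas assume the population variance (ddof=0, NumPy's default); PyTorch's var() defaults to the unbiased estimator instead.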

    Equations taken from: https://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html

    # m:   number of samples (or pixels) over all previous batches
    # n:   number of samples in the new incoming batch
    # mu1: mean of the previous samples
    # mu2: mean of the current batch
    # v1:  variance of the previous samples
    # v2:  variance of the current batch
    
    def combine_means(mu1, mu2, m, n):
        """
        Updates old mean mu1 from m samples with mean mu2 of n samples.
        Returns the mean of the m+n samples.
        """
        return (m/(m+n))*mu1 + (n/(m+n))*mu2
    
    def combine_vars(v1, v2, mu1, mu2, m, n):
        """
        Updates old variance v1 from m samples with variance v2 of n samples.
        Returns the variance of the m+n samples.
        """
        return (m/(m+n))*v1 + n/(m+n)*v2 + m*n/(m+n)**2 * (mu1 - mu2)**2
        
    

    As you can see, both functions could be simplified slightly by reusing shared terms such as m+n; they are kept in this form for readability.
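
A minimal sketch of how `combine_means` and `combine_vars` can be driven over a stream of batches, assuming NumPy and images stored as arrays of shape (N, H, W, C). The function name `update_channel_stats` is hypothetical; it inlines both formulas and reuses m + n as suggested above. Each batch is reduced with vectorized `mean`/`var` calls, so the only Python-level loop is over batches, not pixels.

```python
import numpy as np


def update_channel_stats(mean, var, count, batch):
    """Fold one (N, H, W, C) batch into running per-channel mean/variance.

    Hypothetical helper. `mean` and `var` have shape (C,) (None before the
    first batch); `count` is the number of pixels per channel seen so far.
    Returns the updated (mean, var, count).
    """
    # Every pixel of every image is one sample: flatten N, H, W into one axis.
    # Work in float64 regardless of the input dtype (e.g. uint8 images).
    pixels = batch.reshape(-1, batch.shape[-1]).astype(np.float64)
    n = pixels.shape[0]
    mu2 = pixels.mean(axis=0)
    v2 = pixels.var(axis=0)  # ddof=0: population variance, as the formulas require
    if count == 0:
        return mu2, v2, n
    m = count
    total = m + n  # reused m + n
    new_mean = (m * mean + n * mu2) / total  # same as combine_means
    new_var = (m * var + n * v2) / total + m * n * (mean - mu2) ** 2 / total**2  # combine_vars
    return new_mean, new_var, total
```

Usage: start with `mean, var, count = None, None, 0` and call `mean, var, count = update_channel_stats(mean, var, count, batch)` for each batch as it is loaded; batches may have different sizes.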