
Efficient algorithm for online Variance update over batched data


I have a large amount of multi-dimensional data and want to calculate the variance along one axis across all of it. Memory-wise, I cannot create a single large array to calculate the variance in one step. I therefore need to load the data in batches and update the current variance in an online fashion after each batch.

toy example

In the end, the batch-wise updated online_var should match correct_var. However, I am struggling to find an efficient algorithm for this.

import numpy as np
np.random.seed(0)
# Correct calculation of the variance
all_data = np.random.randint(0, 9, (9, 3))  # <-- does not fit into memory
correct_var = all_data.var(axis=0)
# Create batches
batches = all_data.reshape(-1, 3, 3)

online_var = 0
for batch in batches:
   batch_var = batch.var(axis=0)
   online_var = ?  # how to update this correctly
assert np.allclose(correct_var, online_var)

I found Welford's online algorithm; however, it is very slow because it only updates the variance for a single new value, i.e. it cannot process a whole batch at once. Since I am working with images, an update is necessary for each pixel and each channel.
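For comparison, here is a minimal sketch of the per-observation Welford update the question refers to, assuming each observation is one row (a vector of per-channel values) of the toy data. `welford_update` is a hypothetical helper name. The Python-level loop over every single observation is what makes this approach slow for images.

```python
import numpy as np

def welford_update(count, mean, m2, x):
    """Hypothetical helper: one Welford step folding a single observation x into the running stats.

    m2 is the running sum of squared deviations from the mean.
    """
    count += 1
    delta = x - mean            # deviation from the old mean
    mean = mean + delta / count
    m2 = m2 + delta * (x - mean)  # uses both old (delta) and new mean
    return count, mean, m2

np.random.seed(0)
all_data = np.random.randint(0, 9, (9, 3))
batches = all_data.reshape(-1, 3, 3)

count, mean, m2 = 0, np.zeros(3), np.zeros(3)
for batch in batches:
    for row in batch:  # one Python call per observation: the slow part
        count, mean, m2 = welford_update(count, mean, m2, row)

welford_var = m2 / count  # population variance, matching np.var's default ddof=0
```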


How can I update the variance for multiple new observations in an efficient way that considers the whole batch at once?


Solution

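  • These are the two functions needed to update (combine) the mean and variance of two batches. Both operate element-wise on NumPy arrays, so they work with vectors such as per-channel statistics. The per-batch mean and variance can be obtained from built-in methods such as batch.mean() and batch.var(); note that the formulas assume the population variance (NumPy's default ddof=0).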

    Equations taken from: https://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html

    # m: number of samples (or pixels) over all previous batches
    # n: number of samples in the new incoming batch
    # mu1: previous mean
    # mu2: mean of the current batch
    # v1: previous variance
    # v2: variance of the current batch
    
    def combine_means(mu1, mu2, m, n):
        """
        Updates old mean mu1 from m samples with mean mu2 of n samples.
        Returns the mean of the m+n samples.
        """
        return (m/(m+n))*mu1 + (n/(m+n))*mu2
    
    def combine_vars(v1, v2, mu1, mu2, m, n):
        """
        Updates old variance v1 from m samples with variance v2 of n samples.
        Returns the variance of the m+n samples.
        """
        return (m/(m+n))*v1 + n/(m+n)*v2 + m*n/(m+n)**2 * (mu1 - mu2)**2
        
    

    As you can see, the functions can be simplified slightly by reusing intermediate results such as m+n; they are kept in this pure form for better readability.
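Written out, the two functions in the answer compute the following, where $m, \mu_1, v_1$ describe all previous batches and $n, \mu_2, v_2$ the new batch (all variances are population variances, i.e. divided by the sample count). The cross term comes from summing squared deviations about the combined mean $\mu$, using $\mu_1 - \mu = \frac{n}{m+n}(\mu_1 - \mu_2)$ and $\mu_2 - \mu = -\frac{m}{m+n}(\mu_1 - \mu_2)$:

```latex
\mu = \frac{m}{m+n}\,\mu_1 + \frac{n}{m+n}\,\mu_2

(m+n)\,v = m\left(v_1 + (\mu_1 - \mu)^2\right) + n\left(v_2 + (\mu_2 - \mu)^2\right)
         = m v_1 + n v_2 + \frac{mn}{m+n}\,(\mu_1 - \mu_2)^2

v = \frac{m}{m+n}\,v_1 + \frac{n}{m+n}\,v_2 + \frac{mn}{(m+n)^2}\,(\mu_1 - \mu_2)^2
```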
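A minimal sketch of how the two functions could fill in the `?` in the question's toy loop (a reconstruction for illustration, not code from the original post). The running statistics start from zero samples, and the variance has to be combined before the running mean is overwritten, because `combine_vars` needs the old mean.

```python
import numpy as np

def combine_means(mu1, mu2, m, n):
    """Mean of m + n samples from the mean mu1 of m samples and mu2 of n samples."""
    return (m/(m+n))*mu1 + (n/(m+n))*mu2

def combine_vars(v1, v2, mu1, mu2, m, n):
    """Population variance of m + n samples from the two partial variances and means."""
    return (m/(m+n))*v1 + n/(m+n)*v2 + m*n/(m+n)**2 * (mu1 - mu2)**2

np.random.seed(0)
all_data = np.random.randint(0, 9, (9, 3))
correct_var = all_data.var(axis=0)
batches = all_data.reshape(-1, 3, 3)

online_mean, online_var, m = 0.0, 0.0, 0  # running stats; m = samples seen so far
for batch in batches:
    n = batch.shape[0]
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)  # ddof=0, as the formulas require
    # Combine the variance first: it needs the *old* running mean.
    online_var = combine_vars(online_var, batch_var, online_mean, batch_mean, m, n)
    online_mean = combine_means(online_mean, batch_mean, m, n)
    m += n

assert np.allclose(correct_var, online_var)
```

With m = 0 on the first batch, both functions simply return that batch's statistics, so no special case is needed for initialization.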
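A sketch of the same update with the shared terms (m + n and the mean difference) computed once, applied to image batches with per-channel statistics. The helper name `update_stats`, the (N, H, W, C) layout and the random stand-in data are assumptions for illustration; reducing over `axis=(0, 1, 2)` treats every pixel of every image as one observation per channel.

```python
import numpy as np

def update_stats(mean, var, count, batch, axis=(0, 1, 2)):
    """Hypothetical helper: fold a whole batch into running (mean, var, count).

    var is the population variance (ddof=0); statistics are kept per remaining axis (channels).
    """
    n = int(np.prod([batch.shape[a] for a in axis]))  # observations per channel in this batch
    b_mean = batch.mean(axis=axis)
    b_var = batch.var(axis=axis)
    total = count + n          # the reused m + n
    delta = b_mean - mean      # the reused mean difference
    new_mean = mean + delta * (n / total)
    new_var = (count * var + n * b_var) / total + delta**2 * (count * n / total**2)
    return new_mean, new_var, total

rng = np.random.default_rng(0)
images = rng.random((20, 8, 8, 3))  # stand-in for an image dataset that does not fit in memory

mean, var, count = np.zeros(3), np.zeros(3), 0
for start in range(0, len(images), 5):  # 5 images per batch
    mean, var, count = update_stats(mean, var, count, images[start:start + 5])
```

Because each batch is summarized by NumPy's vectorized mean and variance before combining, the only Python-level loop runs over batches, not over individual pixels.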