machine-learning, deep-learning, loss-function, tensorflow.js, autoencoder

LogVar layer of a VAE only returns zeros


I'm building a variational autoencoder (VAE) with tfjs. For now I'm just experimenting with the Fashion-MNIST dataset and a simple model as follows:

input layer (28*28*1)
flatten
intermediate_1 (dense, 50 units, relu)
mean (dense, 10 units, relu) // logVar (dense, 10 units, relu)
SampleLayer (10 units)
intermediate_2 (dense, 50 units, relu)
reconstructed (dense, 784 units, sigmoid)
reshape (28*28*1) => loss = MSE

I created a custom sampling layer extending tf.layers.Layer as shown below. It's different from other examples that I could find online because I'm using this.addLoss() to add the KL loss function in the layer itself.

class wb_sampling extends tf.layers.Layer {
    constructor(config) {
        super(config);
        this.KL_weight = config.KL_weight; // weights the KL loss relative to the reconstruction loss. If KL_weight == 0, the reconstruction loss is the only loss used
        if (this.KL_weight === undefined) {
            this.KL_weight=0.0001; // default
        }
        this.last_mu = null;     // will hold zMean from the last forward pass
        this.last_logVar = null; // will hold zLogVar from the last forward pass
        
        // Adds the KL loss, computed from the mu / logVar tensors stored in call()
        this.addLoss(() => {
            return tf.tidy(() => {
                const z_mean = this.last_mu;
                const z_log_var = this.last_logVar;
                // KL divergence between N(mu, sigma^2) and N(0, 1):
                // -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
                let kl_loss = tf.scalar(1).add(z_log_var).sub(z_mean.square()).sub(z_log_var.exp());
                kl_loss = tf.sum(kl_loss, -1);
                kl_loss = kl_loss.mul(tf.scalar(-0.5 * this.KL_weight));
                return tf.mean(kl_loss);
            });
        }); // end of addLoss
    } // end of constructor

    computeOutputShape(inputShape) {
        return inputShape[0]; // same shape as mu
    }

    call(inputs, training) {
        return tf.tidy(() => {
            const [mu, logVar] = inputs;
            
            // store mu and logVar values to be used by the KL loss function
            this.last_mu=mu; // zMean
            this.last_logVar=logVar; // zLogVar
            
            const z = tf.tidy(() => {
                const batch = mu.shape[0];
                const dim = mu.shape[1];
                const epsilon = tf.randomNormal([batch, dim]);
                const half = tf.scalar(0.5);
                const temp = logVar.mul(half).exp().mul(epsilon);
                const sample = mu.add(temp);
                return sample;
            });
            return z;
        });
    } // end of call()

    static get className() {
        return 'wb_sampling';
    }
} // end of wb_sampling layer
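
For context, here is a minimal sketch of how the layer is wired into the model described above (the exact layer names and the choice of optimizer are incidental):

tf.serialization.registerClass(wb_sampling);

const input = tf.input({shape: [28, 28, 1]});
const flat = tf.layers.flatten().apply(input);
const intermediate1 = tf.layers.dense({units: 50, activation: 'relu'}).apply(flat);
const mean = tf.layers.dense({units: 10, activation: 'relu', name: 'mean'}).apply(intermediate1);
const logVar = tf.layers.dense({units: 10, activation: 'relu', name: 'logVar'}).apply(intermediate1);
const z = new wb_sampling({KL_weight: 0.0001}).apply([mean, logVar]);
const intermediate2 = tf.layers.dense({units: 50, activation: 'relu'}).apply(z);
const reconstructed = tf.layers.dense({units: 784, activation: 'sigmoid'}).apply(intermediate2);
const output = tf.layers.reshape({targetShape: [28, 28, 1]}).apply(reconstructed);

const vae = tf.model({inputs: input, outputs: output});
vae.compile({optimizer: 'adam', loss: 'meanSquaredError'});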

The model works well, and the reconstruction is correct, but something seems strange to me: the output tensor of the logVar layer contains only zeros.

I tried with more or fewer units in the mean / logVar / sample layers (from 2 to 20). I tried changing the KL_weight parameter (which weights the KL loss against the reconstruction loss). I even tried a KL_weight of 0 (meaning the KL loss is disregarded entirely and the model is driven only by the reconstruction MSE loss). In every case, the output of the logVar layer contains only zeros (note: at the beginning of training there are varied values, but after a few training steps only zeros remain in the output).
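
To see this, the logVar activations can be inspected with a small probe model along these lines (input and logVar are the symbolic tensors from the sketch above; xBatch stands for a batch of images with shape [batchSize, 28, 28, 1]):

// Probe model that exposes the logVar activations directly
const logVarProbe = tf.model({inputs: input, outputs: logVar});
logVarProbe.predict(xBatch).print(); // after a few training steps this prints only zeros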

The mean layer, however, outputs varied values, and I noticed that they are smaller when KL_weight > 0 than when KL_weight == 0, so the KL loss function seems to be working.

Could this be normal? Maybe outputs higher than zero in the logVar layer wouldn't improve the reconstruction in a task as simple as this one?

Have you experienced such all-zero outputs from your logVar layer? If not, what values do you usually see there? And do you have any idea what may cause the problem?


Solution

  • This is due to the relu activations on the encoder outputs. To understand why this is an issue: the KL-divergence loss pulls the variance output toward 1. At the same time, the reconstruction loss generally pulls it toward 0, since lower variance makes the codes more reliable, leading to better reconstructions.

    Thus, variances can be expected to be between 0 and 1. That corresponds to log variances of 0 (equal to variance of 1) or below (variance smaller than 1 -> negative log variance). But with a relu output, you cannot get values less than 0 for the log variance, and a value of exactly 0 is the best the model can do.

    Similarly, relu for the means also doesn't really make sense, as there is no reason why the mean values shouldn't be negative.
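
    A minimal sketch of the fix (assuming the mean / logVar layers are built with tf.layers.dense as in the question's sketch; intermediate1 stands for the output of the intermediate_1 layer): drop the relu on the mean and logVar layers so they fall back to the default linear activation.

    // With no activation specified, the dense layers are linear, so logVar can go
    // negative and mean can take either sign.
    // The KL term -0.5 * sum(1 + logVar - mean^2 - exp(logVar)) is minimized at
    // mean = 0 and variance = 1 (logVar = 0); combined with the reconstruction loss,
    // trained log variances typically end up at 0 or below, which relu cannot reach.
    const mean = tf.layers.dense({units: 10}).apply(intermediate1);
    const logVar = tf.layers.dense({units: 10}).apply(intermediate1);
    const z = new wb_sampling({KL_weight: 0.0001}).apply([mean, logVar]);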