keras early-stopping

How to pick one validation loss value for EarlyStopping when multiple losses are returned


I am training a variational autoencoder with a custom loss that has three components:

  1. total_loss
  2. reconstruction_loss
  3. kl_loss

The full code is too long to post, so I hope this snippet conveys the problem (the running code is here in Colab). With these two functions I can train my model, and at the end of every epoch I get the values for all three losses.

def train_step(self, data):
    if isinstance(data, tuple):
        data = data[0]
    with tf.GradientTape() as tape:
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(tf.square(data - reconstruction),
                                             axis = [1,2,3])
        reconstruction_loss *= self.r_loss_factor
        kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        kl_loss = tf.reduce_sum(kl_loss, axis = 1)
        kl_loss *= -0.5
        total_loss = reconstruction_loss + kl_loss
    grads = tape.gradient(total_loss, self.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    return {
        "loss": total_loss,
        "reconstruction_loss": reconstruction_loss,
        "kl_loss": kl_loss,
    }

def test_step(self, input_data):
    validation_data = input_data[0] # <-- Separate X and y
    z_mean, z_log_var, z = self.encoder(validation_data)
    val_reconstruction = self.decoder(z)
    val_reconstruction_loss = tf.reduce_mean(tf.square(validation_data - val_reconstruction),
                                 axis = [1,2,3])
    val_reconstruction_loss *= self.r_loss_factor

    val_kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
    val_kl_loss = tf.reduce_sum(val_kl_loss, axis = 1)
    val_kl_loss *= -0.5
    val_total_loss = val_reconstruction_loss + val_kl_loss
    return {
        "loss": val_total_loss,
        "reconstruction_loss": val_reconstruction_loss,
        "kl_loss": val_kl_loss,
    }

However, if I add an EarlyStopping callback like this:

early_stopper = EarlyStopping(monitor='val_loss', min_delta=0.001, patience=10)

I get an error at the end of the epoch:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

My hunch is that the early stopper is receiving three values to judge from. How can I select only the total validation loss as the early stopping criterion?
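The error text itself comes from NumPy and can be reproduced without Keras. Below is a minimal sketch, assuming EarlyStopping checks for improvement with something like `np.less(current - min_delta, best)` (my reading of the callback, not confirmed here) and that `val_loss` arrives as one value per sample rather than as a single number. The loss values are made up.

```python
import numpy as np

min_delta = 0.001
best = np.inf

# Stand-in for what test_step returns under "loss": one value per sample.
per_sample_val_loss = np.array([0.9, 1.1, 1.0, 1.2])

# Element-wise comparison gives one boolean per sample, not a single verdict.
improved = np.less(per_sample_val_loss - min_delta, best)

try:
    if improved:  # the truth value of a multi-element array is ambiguous
        best = per_sample_val_loss
    error_message = None
except ValueError as err:
    error_message = str(err)

print(error_message)

# Reducing to a scalar first makes the same check well defined.
scalar_val_loss = float(per_sample_val_loss.mean())
scalar_improved = bool(np.less(scalar_val_loss - min_delta, best))
```

If this reading is right, the ambiguity would come from the per-sample shape of each returned loss rather than from the number of entries in the returned dictionary.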


Solution

  • After much googling, I finally found the answer I was looking for here. It turns out I had to make the following modifications.
    Add these attributes to the VAEModel class (Colab notebook here):

    self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
    self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
                name="reconstruction_loss"
            )
    self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
    
    self.val_total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
    self.val_reconstruction_loss_tracker = tf.keras.metrics.Mean(
                name="reconstruction_loss"
            )
    self.val_kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
    

    Add two properties (one each for train and val):

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]
    
    @property
    def val_metrics(self):
        return [
            self.val_total_loss_tracker,
            self.val_reconstruction_loss_tracker,
            self.val_kl_loss_tracker,
        ]
    

    Finally, in the train and test steps, I was able to report the losses by adding these lines:

    # Add to the train step    
    self.total_loss_tracker.update_state(total_loss)
    self.reconstruction_loss_tracker.update_state(reconstruction_loss)
    self.kl_loss_tracker.update_state(kl_loss)
    
    return {
        "loss": self.total_loss_tracker.result(),
        "reconstruction_loss": self.reconstruction_loss_tracker.result(),
        "kl_loss": self.kl_loss_tracker.result(),
    }
    
    
    # Add to the test step
    self.val_total_loss_tracker.update_state(val_total_loss)
    self.val_reconstruction_loss_tracker.update_state(val_reconstruction_loss)
    self.val_kl_loss_tracker.update_state(val_kl_loss)
    
        
    return {
        "loss": self.val_total_loss_tracker.result(),
        "reconstruction_loss": self.val_reconstruction_loss_tracker.result(),
        "kl_loss": self.val_kl_loss_tracker.result(),
    }
    

    The full working version is available in Colab here. Early stopping can now read each loss as a single value. In all honesty, it will take me a while to understand why it works. Nevertheless, my code is at least functional.
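A likely explanation, sketched under assumptions rather than taken from the Keras source: `tf.keras.metrics.Mean` keeps a running total and count, so `result()` is a single number even when `update_state` is fed a whole batch of per-sample losses. The original `test_step` returned that per-sample batch directly, which is probably what EarlyStopping could not compare. `RunningMean` below is a hypothetical pure-Python stand-in, not the Keras class, and the loss values are made up. It also illustrates a possible caveat: as far as I know, Keras resets only the trackers returned by `metrics`, and `val_metrics` is not a name Keras looks up, so the validation trackers may keep averaging over every epoch so far.

```python
class RunningMean:
    """Hypothetical pure-Python stand-in for tf.keras.metrics.Mean."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        # Accept a single number or a batch of per-sample values.
        values = list(values) if hasattr(values, "__iter__") else [values]
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        # Always one number, no matter how many values were fed in.
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        self.total = 0.0
        self.count = 0


# Made-up per-sample validation losses for two epochs.
epoch_1 = [2.0, 4.0]
epoch_2 = [1.0, 1.0]

# Tracker that is reset between epochs.
with_reset = RunningMean()
with_reset.update_state(epoch_1)
val_loss_1 = with_reset.result()
with_reset.reset_state()
with_reset.update_state(epoch_2)
val_loss_2 = with_reset.result()

# Tracker that is never reset: epoch 1 still weighs on the epoch 2 value.
without_reset = RunningMean()
without_reset.update_state(epoch_1)
without_reset.update_state(epoch_2)
stale_val_loss_2 = without_reset.result()

print(val_loss_1, val_loss_2, stale_val_loss_2)
```

If the stale average turns out to matter for early stopping, one option worth trying is to return the validation trackers from `metrics` as well, or to reuse the training trackers inside `test_step`.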