Tags: python, tensorflow, keras, gradienttape

If-Else Statement in a Custom Training Loop in TensorFlow


I created a model class that subclasses keras.Model. While training the model, I want to change the weights of the loss functions after some epochs. To do that, I added a boolean attribute to my model indicating whether training with the additional loss function should start. Below is pseudocode that shows what I am trying to achieve.

class MyModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Flag indicating whether loss_2 should be used yet
        self.start_loss_2 = False

    def train_step(self, data):
        # Check if training with loss_2 has started
        weight_loss_2 = 0.0
        if self.start_loss_2:
            weight_loss_2 = 0.5

        # Pass the data through the model
        # Calculate the two loss values
        total_loss = loss_1 + weight_loss_2 * loss_2
        # Calculate gradients with tf.GradientTape
        # Update variables

    # This is called via a Callback after each epoch
    def epoch_finished(self, epoch_num):
        if epoch_num > START_LOSS_2:
            self.start_loss_2 = True
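
For reference, I wire the flag flip through a standard Keras Callback, roughly like this (EpochFinishedCallback is just an illustrative name):

import tensorflow as tf

class EpochFinishedCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # forward the epoch number to the model's hook
        self.model.epoch_finished(epoch)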


My question is: is it possible to change the weight of a loss function during training like this?


Solution

  • Yes. You can create a tf.Variable and then assign a new value to it based on some training criterion. A self-contained example follows, and after it a short sketch mapping the same idea back onto the question's Model/Callback structure.

    Example:

    import numpy as np
    import tensorflow as tf
    
    
    # simple toy network
    x_in = tf.keras.Input((10,))
    x = tf.keras.layers.Dense(25)(x_in)
    x_out = tf.keras.layers.Dense(1)(x)
    
    # model
    m = tf.keras.Model(x_in, x_out)
    
    # fake data (targets shaped (100, 1) to match the model output,
    # so MSE compares element-wise instead of broadcasting)
    X = tf.random.normal((100, 10))
    y0 = tf.random.normal((100, 1))
    y1 = tf.random.normal((100, 1))
    
    # optimizer
    m_opt = tf.keras.optimizers.Adam(1e-2)
    
    # prep data
    ds = tf.data.Dataset.from_tensor_slices((X, y0, y1))
    ds = ds.repeat().batch(5)
    train_iter = iter(ds)
    
    # toy loss function that uses a weight
    def loss_fn(y_true0, y_true1, y_pred, weight):
        mse = tf.keras.losses.MSE
        mse_0 = tf.math.reduce_mean(mse(y_true0, y_pred))
        mse_1 = tf.math.reduce_mean(mse(y_true1, y_pred))
        return mse_0 + weight * mse_1
      
    NUM_EPOCHS = 4
    NUM_BATCHES_PER_EPOCH = 10
    START_NEW_LOSS_AT_GLOBAL_STEP = 20
    
    # the weight variable set to 0 initially and then
    # will be changed after a certain number of steps
    # (or some other training criteria)
    w = tf.Variable(0.0, trainable=False)
    
    for epoch in range(NUM_EPOCHS):
        losses = []
        for batch in range(NUM_BATCHES_PER_EPOCH):
            X_train, y0_train, y1_train = next(train_iter)
            with tf.GradientTape() as tape:
                y_hat = m(X_train)
                loss = loss_fn(y0_train, y1_train, y_hat, w)
                losses.append(loss)
        
            m_vars = m.trainable_variables
            m_grads = tape.gradient(loss, m_vars)
            m_opt.apply_gradients(zip(m_grads, m_vars))
        
        print(f"epoch: {epoch}\tloss: {np.mean(losses):.4f}")
        losses = []
    
        # once enough global steps have elapsed, assign a huge
        # weight to see if the loss spikes up
        if (epoch + 1) * NUM_BATCHES_PER_EPOCH >= START_NEW_LOSS_AT_GLOBAL_STEP:
            w.assign(10000.0)
    
    # epoch: 0  loss: 1.8226
    # epoch: 1  loss: 1.1143
    # epoch: 2  loss: 8788.2227    <= looks like assign worked
    # epoch: 3  loss: 10999.5449
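
    To map this back onto the question's structure, here is a minimal sketch of the same idea inside a keras.Model subclass, with a Callback doing the assignment. This is an illustrative adaptation, not the original answer's code; MyModel, StartLoss2Callback, and the MSE losses are assumptions standing in for the question's setup.

    import tensorflow as tf

    class MyModel(tf.keras.Model):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.dense = tf.keras.layers.Dense(1)
            # a non-trainable variable instead of a plain Python bool,
            # so the new value is picked up even if train_step is compiled
            self.weight_loss_2 = tf.Variable(0.0, trainable=False)

        def call(self, x):
            return self.dense(x)

        def train_step(self, data):
            x, (y0, y1) = data
            with tf.GradientTape() as tape:
                y_pred = self(x, training=True)
                loss_1 = tf.math.reduce_mean(tf.math.square(y0 - y_pred))
                loss_2 = tf.math.reduce_mean(tf.math.square(y1 - y_pred))
                total_loss = loss_1 + self.weight_loss_2 * loss_2
            grads = tape.gradient(total_loss, self.trainable_variables)
            self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
            return {"loss": total_loss}

    # assigns the new loss weight once the chosen epoch is reached
    class StartLoss2Callback(tf.keras.callbacks.Callback):
        def __init__(self, start_epoch, weight=0.5):
            super().__init__()
            self.start_epoch = start_epoch
            self.weight = weight

        def on_epoch_end(self, epoch, logs=None):
            if epoch >= self.start_epoch:
                self.model.weight_loss_2.assign(self.weight)

    # usage
    model = MyModel()
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-2))
    X = tf.random.normal((100, 10))
    y0 = tf.random.normal((100, 1))
    y1 = tf.random.normal((100, 1))
    model.fit(X, (y0, y1), batch_size=5, epochs=4,
              callbacks=[StartLoss2Callback(start_epoch=1)])

    Because the weight lives in a tf.Variable, the callback's assign takes effect from the next batch onward without any retracing or recompilation of train_step.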