python, tensorflow, autoencoder

Using a custom parameter in train_step() method of a VAE, which is different for each epoch


I am trying to build a VAE, similar to the example here: https://keras.io/examples/generative/vae/ However, I am using a custom loss in the model, and I would like this loss to be multiplied by a factor that increases with each epoch. I did this by defining, at the beginning,

def __init__(self, encoder, decoder, **kwargs):
    super().__init__(**kwargs)
    self.eloss_weight = tf.Variable(initial_value=args.eloss_weight, trainable=False)

compiled the model to run eagerly,

vae.compile(optimizer=tf.keras.optimizers.Adam(jit_compile=False), run_eagerly=True)

and when fitting I change the value of the energy loss weight using a callback like this:

def eloss_weight_increase(epoch, logs):
    vae.eloss_weight = vae.eloss_weight + 1


increase_eloss_weight = tf.keras.callbacks.LambdaCallback(on_epoch_end=eloss_weight_increase)

vae.fit(
    X_train, V_train, batch_size=args.batch_size, epochs=args.epochs,
    callbacks=[increase_eloss_weight], verbose=1,)

I had to use run_eagerly because otherwise, when the parameter was read as self.eloss_weight inside train_step, changes to it were not picked up (I guess that, to make the model as fast as possible, when train_step was compiled it just remembered the exact value and stopped treating it as a variable). However, because of run_eagerly the model now trains about 4 times slower, which is pretty bad: before this change it took 4.5 hours to train, now it takes almost a day. Is there any way to do this without run_eagerly=True?

Thank you very much.


Solution

  • vae.eloss_weight = vae.eloss_weight + 1 overwrites the Variable with a Tensor (not the same thing), so the compiled train_step never sees the update. Use the variable's assign method so the Variable object stays intact:

    vae.eloss_weight.assign(vae.eloss_weight + 1.)
    

    or

    vae.eloss_weight.assign_add(1.)
    

    I have used this myself before, so it should work, also without eager execution.
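
    The reason this works is that a tf.function-compiled step captures a tf.Variable by reference, not by value, so in-place updates via assign/assign_add are visible on every call even in graph mode. A minimal sketch of that behavior (standalone, with a hypothetical weighted_loss function standing in for the VAE's loss term):

    ```python
    import tensorflow as tf

    # Non-trainable variable, as in the VAE's __init__.
    weight = tf.Variable(1.0, trainable=False)

    @tf.function  # compiled, i.e. NOT run eagerly
    def weighted_loss(loss):
        # The variable is captured by reference: its current
        # value is read each time the function is called.
        return weight * loss

    print(float(weighted_loss(tf.constant(2.0))))  # 1.0 * 2.0 -> 2.0
    weight.assign_add(1.0)  # the fix from above; works in graph mode
    print(float(weighted_loss(tf.constant(2.0))))  # 2.0 * 2.0 -> 4.0
    ```

    Had weight been a plain Python float instead, it would have been baked into the trace as a constant at the first call, which is exactly the behavior described in the question.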