My model is an encoder that has input `Z` and output `x`. I'm trying to use a `total_loss` that combines a traditional supervised learning loss with regularization term(s). I have additional functions (outside the network) that use the input `Z` and the predicted output `x_pred` to calculate their respective regularization terms to include in the loss calculation.
```python
# Custom training function within model class
def train_step(self, Z, x):
    # Define loss object
    loss_object = tf.keras.losses.MeanSquaredError()
    with tf.GradientTape() as tape:
        # Get encoder output
        x_pred = self.encoder(Z)
        # Calculate traditional supervised learning data loss
        data_loss = loss_object(x, x_pred)
        # Calculate regularization terms
        x_hat, Z_pred = calc_reg_terms(x_pred, Z)  # physics-informed function
        # Calculate respective regularization losses
        loss_x = loss_object(x, x_hat)
        loss_z = loss_object(Z, Z_pred)
        """<Additional Code>"""
```
What is the correct method for calculating the gradient of my `total_loss`?
In the past, I've tried simply adding all the loss terms together, then taking the gradient of the summed loss.
```python
### PAST METHOD ###
# Calculate total loss
total_loss = data_loss + a * loss_x + b * loss_z  # a, b -> set hyperparameters
# Get gradients
grads = tape.gradient(total_loss, self.trainable_weights)
```
However, since my `loss_x` and `loss_z` are defined outside the encoder, I fear that these losses act more as a bias on the `total_loss` calculation: the model actually performs worse when they are added to `data_loss`. The `data_loss` term has a clear connection to the trainable weights of the encoder, making for a clear gradient calculation, but the same cannot easily be said for my regularization loss terms.
NOTE: Tracking each of these three losses during training shows that `data_loss` can decrease with each passing training epoch, but both `loss_x` and `loss_z` tend to plateau early on during training, hence the fear that they act more as an unwanted bias on the `total_loss`.
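For reference, the kind of check I've been considering is to take the gradient of each loss term on its own with a persistent tape and see whether `loss_x` and `loss_z` reach the encoder weights at all. This is only a rough diagnostic sketch: `encoder`, `Z_batch`, and `x_batch` are stand-ins for my model and one batch of data, and it assumes `calc_reg_terms` is built from TensorFlow ops that the tape can trace.

```python
import tensorflow as tf

# Stand-ins: `encoder` is the model, `Z_batch`/`x_batch` one batch of data,
# `calc_reg_terms` is the physics-informed function from above.
loss_object = tf.keras.losses.MeanSquaredError()

with tf.GradientTape(persistent=True) as tape:
    x_pred = encoder(Z_batch)
    data_loss = loss_object(x_batch, x_pred)
    x_hat, Z_pred = calc_reg_terms(x_pred, Z_batch)
    loss_x = loss_object(x_batch, x_hat)
    loss_z = loss_object(Z_batch, Z_pred)

# Gradient of each term taken separately; a `None` entry means that term
# is disconnected from the encoder's trainable weights.
for name, term in [("data_loss", data_loss), ("loss_x", loss_x), ("loss_z", loss_z)]:
    grads = tape.gradient(term, encoder.trainable_weights)
    n_disconnected = sum(g is None for g in grads)
    print(f"{name}: {n_disconnected} of {len(grads)} weight tensors get no gradient")

del tape  # release the persistent tape explicitly
```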
What is the proper way, then, to calculate the gradients with the `data_loss`, `loss_x`, and `loss_z` terms?
Thanks for the clarification in your comment, it makes sense.

Your code looks correct to me -- that is the general approach. Calculate `total_loss = data_reconstruction_loss + constant * regularization_loss`, then calculate the gradient of the `total_loss` and backpropagate. A simple way to make sure that it's working without doing a full hyperparameter sweep is to set `a=0` and `b=0`, then gradually increase `a` from some very small value (e.g., `a=1E-10`) to a large value (e.g., `a=1`). You can take big steps, but you should see your train and validation loss change as you sweep across values of `a`. You can then repeat the same process with `b`. If everything works out, continue to the hyperparameter sweep.
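For concreteness, here is a minimal sketch of what that full training step could look like. The class name and constructor are placeholders, I'm assuming the model is compiled so that `self.optimizer` exists, and `calc_reg_terms` is your physics-informed function, which needs to consist of differentiable TensorFlow ops for any gradient to flow through `loss_x` and `loss_z`:

```python
import tensorflow as tf

class PhysicsInformedModel(tf.keras.Model):
    def __init__(self, encoder, a=1e-4, b=1e-4):
        super().__init__()
        self.encoder = encoder
        self.a = a  # weight on loss_x (hyperparameter)
        self.b = b  # weight on loss_z (hyperparameter)
        self.loss_object = tf.keras.losses.MeanSquaredError()

    def train_step(self, data):
        Z, x = data
        with tf.GradientTape() as tape:
            x_pred = self.encoder(Z, training=True)
            data_loss = self.loss_object(x, x_pred)
            # Physics-informed terms; must be TensorFlow ops to stay on the tape.
            x_hat, Z_pred = calc_reg_terms(x_pred, Z)
            loss_x = self.loss_object(x, x_hat)
            loss_z = self.loss_object(Z, Z_pred)
            total_loss = data_loss + self.a * loss_x + self.b * loss_z
        # One gradient of the summed loss, then one optimizer step.
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {"total_loss": total_loss, "data_loss": data_loss,
                "loss_x": loss_x, "loss_z": loss_z}
```

Sweeping `a` and then `b` as described above then amounts to re-running training with different values of those two constructor arguments and watching how the train and validation losses respond.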