I am trying to verify whether a custom training loop changes a Keras Model's weights. My current method is to deepcopy the model.trainable_weights list before training and then compare that copy to model.trainable_weights after training. Is this a valid way to make the comparison? My results indicate that the weights do in fact change (the expected result anyway, since the loss clearly decreases per epoch), but I want to confirm that the method itself is valid. Below is code from the slightly adapted Keras custom training loop tutorial, plus the code I use to compare the weights before and after training:
# Imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from copy import deepcopy
# The model
inputs = keras.Input(shape=(784,), name="digits")
x1 = layers.Dense(64, activation="relu")(inputs)
x2 = layers.Dense(64, activation="relu")(x1)
outputs = layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)
##########################
# WEIGHTS BEFORE TRAINING
##########################
# deepcopy the variables themselves; a plain list copy would still reference
# the same tf.Variable objects, which are updated in place during training
weights_before_training = deepcopy(model.trainable_weights)
##########################
# Keras Tutorial
##########################
# Load data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))
# Reduce the size of the data to speed up training
x_train = x_train[:128]
x_test = x_test[:128]
y_train = y_train[:128]
y_test = y_test[:128]
# Make tf dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=64).batch(16)
# The training loop
print('Begin Training')
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
epochs = 2
for epoch in range(epochs):
    # Logging start of epoch
    print("\nStart of epoch %d" % (epoch,))
    # Save loss values for logging
    loss_values = []
    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)  # Logits for this minibatch
            loss_value = loss_fn(y_batch_train, logits)
        # Append to list for logging
        loss_values.append(loss_value)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
    print('Epoch Loss:', np.mean(loss_values))
print('End Training')
##########################
# WEIGHTS AFTER TRAINING
##########################
weights_after_training = model.trainable_weights
# Note: `trainable_weights` is a list of the layers' kernel and bias variables (tf.Variable objects).
print()
print('Begin Trainable Weights Comparison')
for i, (before, after) in enumerate(zip(weights_before_training, weights_after_training), start=1):
    print(f'Trainable Tensors for Element {i} of List Are Equal:', tf.reduce_all(tf.equal(before, after)).numpy())
print('End Trainable Weights Comparison')
>>> Begin Training
>>> Start of epoch 0
>>> Epoch Loss: 44.66055
>>>
>>> Start of epoch 1
>>> Epoch Loss: 5.306543
>>> End Training
>>>
>>> Begin Trainable Weights Comparison
>>> Trainable Tensors for Element 1 of List Are Equal: False
>>> Trainable Tensors for Element 2 of List Are Equal: False
>>> Trainable Tensors for Element 3 of List Are Equal: False
>>> Trainable Tensors for Element 4 of List Are Equal: False
>>> Trainable Tensors for Element 5 of List Are Equal: False
>>> Trainable Tensors for Element 6 of List Are Equal: False
>>> End Trainable Weights Comparison
Summarizing from the comments and adding some more information, for the benefit of the community:

The method followed in the above code, i.e., comparing a deepcopy(model.trainable_weights) snapshot taken before training with model.trainable_weights after the model is trained using the custom training loop, is the right approach. deepcopy creates independent copies of the tf.Variable objects, so the before-training snapshot is not overwritten when the optimizer updates the model's variables in place. An equivalent cross-check using model.get_weights() is sketched below.
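As a cross-check (an addition here, not from the original comments), the same comparison can be made with model.get_weights(), which returns NumPy copies of the variables, so no deepcopy is needed. Note that get_weights() returns both trainable and non-trainable weights; for the all-Dense model above the two sets coincide. A minimal sketch, with illustrative names weights_before / weights_after:

# get_weights() returns NumPy copies, so this snapshot cannot be
# mutated by training
weights_before = model.get_weights()

# ... run the custom training loop shown above ...

weights_after = model.get_weights()
# One boolean per tensor: True if that tensor changed during training
changed = [not np.array_equal(b, a) for b, a in zip(weights_before, weights_after)]
print('Any weights changed:', any(changed))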
In addition to that, if we don't want the model to be trained at all, we can freeze all the layers of the model with model.trainable = False, as sketched below.
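For completeness, a minimal sketch of what freezing means for this particular custom loop (the printed counts assume the three-Dense-layer model defined above): setting model.trainable = False empties model.trainable_weights, so the gradient/update step has nothing left to modify.

model.trainable = False  # freezes every layer in the model

print(len(model.trainable_weights))      # 0
print(len(model.non_trainable_weights))  # 6 (3 kernels + 3 biases)

# In the custom loop above, tape.gradient(loss_value, model.trainable_weights)
# now returns an empty list, so there are no (gradient, variable) pairs
# for optimizer.apply_gradients to apply.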