While following the TensorFlow tutorial on how to implement a Transformer model, I had some doubts about the training process.
The train_step function is implemented as:
@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    with tf.GradientTape() as tape:
        predictions, _ = transformer([inp, tar_inp],
                                     training=True)
        loss = loss_function(tar_real, predictions)

    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    train_loss(loss)
    train_accuracy(accuracy_function(tar_real, predictions))
We can see that tf.GradientTape() is bound to tape in a with statement. That works, but I don't understand how tape can still be used outside the statement in:

gradients = tape.gradient(loss, transformer.trainable_variables)

Shouldn't tape be closed at the end of the with statement? I implemented the code from the tutorial and it works as is.
Exiting a with statement (with <ContextManager> as <Var>) never deletes the name <Var>; the variable always persists after the block. What usually makes the object unusable afterwards is that most context managers release or invalidate their underlying resource on exit (a file object, for example, is closed), but this is not actually a requirement.
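For instance, a file object is closed by its context manager when the block ends, yet the variable itself still exists afterwards; a minimal sketch:

```python
# The name f bound by "with ... as" survives the block; what changes
# is the state of the object it refers to: __exit__ closes the file.
with open("example.txt", "w") as f:
    f.write("hello")

print(f.closed)  # True: f still exists, but the underlying file is closed
```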
The original proposal for adding the with statement in Python 2.5 (PEP 343) specifies that at the end of the block the context manager's __exit__ method is called, which in this case is tf.GradientTape().__exit__. __exit__ only receives the exception state; it has no way to delete the variable bound by as, so if it does not invalidate the object, the object remains fully usable.
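A minimal custom context manager (the Demo class here is hypothetical, just for illustration) shows that __exit__ runs at the end of the block while the as-bound name stays alive:

```python
class Demo:
    """Minimal context manager: __exit__ runs when the block ends,
    but nothing deletes the name bound by the `as` clause."""
    def __init__(self):
        self.exited = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.exited = True  # release/invalidate resources here if needed
        return False        # do not suppress exceptions

with Demo() as d:
    inside = d.exited   # False: still inside the block

print(d.exited)  # True: __exit__ ran, yet `d` is still usable
```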
In TensorFlow, the developers chose to keep the tape usable after the block: __exit__ merely stops recording operations, so users get the convenience of the with keyword and can still call tape.gradient() on the resulting tape object afterwards. (By default the tape is non-persistent, so its resources are released after a single gradient() call; pass persistent=True to allow multiple calls.)
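A short sketch of the same pattern with tf.GradientTape (assuming TensorFlow is installed; the variable and function here are illustrative):

```python
import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    y = x * x  # recording happens inside the block

# Recording stopped when the block exited, but the tape object
# is still bound and can compute gradients afterwards:
grad = tape.gradient(y, x)  # dy/dx = 2x = 6.0
```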
Additional information: PEP 343 – The “with” Statement