python, tensorflow, with-statement, transformer-model, gradienttape

Transformer tutorial with TensorFlow: GradientTape outside the with statement but still working


While working through the TensorFlow tutorial on how to implement a Transformer model, I had some doubts about the training process.

The train_step function is implemented as:

@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
  tar_inp = tar[:, :-1]
  tar_real = tar[:, 1:]

  with tf.GradientTape() as tape:
    predictions, _ = transformer([inp, tar_inp],
                                 training = True)
    loss = loss_function(tar_real, predictions)

  gradients = tape.gradient(loss, transformer.trainable_variables)
  optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

  train_loss(loss)
  train_accuracy(accuracy_function(tar_real, predictions))

We can see that tf.GradientTape() is bound to the name tape in a with statement. That works, but I don't understand how tape can still be used outside the statement with:

gradients = tape.gradient(loss, transformer.trainable_variables)

Shouldn't tape be closed at the end of the with statement?

I implemented the code from the tutorial and it works as is.


Solution

  • Most context managers used in with statements such as with <ContextManager> as <Var> clean up their resource at the end of the block (a file opened with with open(...) as f is closed, for example), but the with statement itself never deletes <Var>: the name stays bound to the object after the block ends.

    The original proposal for adding the with statement in Python 2.5, PEP 343, specifies that at the end of the block the context manager's __exit__ method is called, which in this case is tf.GradientTape.__exit__. Whatever cleanup __exit__ performs, the variable bound by as persists afterwards; whether the object is still useful depends entirely on what __exit__ did to it. The first sketch below illustrates this with a plain Python context manager.

    In TensorFlow, GradientTape.__exit__ merely stops recording; it does not discard what was already recorded. The developers made this decision so that users can take advantage of the with keyword to delimit the recorded region and still call tape.gradient() on the resulting tape afterwards, as the second sketch below shows.

    Additional information: PEP 343 – The “with” Statement
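
    First, a minimal sketch showing that the with statement never unbinds the target name. The Recorder class here is hypothetical, not a TensorFlow class:

    class Recorder:
        def __enter__(self):
            self.active = True
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # Cleanup only flips a flag; nothing deletes the object,
            # and Python itself never unbinds the name bound by "as".
            self.active = False
            return False  # do not suppress exceptions

    with Recorder() as rec:
        print(rec.active)   # True: inside the block

    print(rec.active)       # False: rec still exists, __exit__ merely ran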
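
    Second, a minimal TensorFlow sketch (assuming TensorFlow 2.x, with a made-up toy variable) showing the same pattern as the tutorial's train_step: record inside the with block, differentiate after it:

    import tensorflow as tf

    x = tf.Variable(3.0)

    with tf.GradientTape() as tape:
        y = x * x  # recorded while the tape is active

    # __exit__ has already run: recording has stopped, but the tape
    # object and the operations it recorded are still available here.
    grad = tape.gradient(y, x)
    print(grad.numpy())  # 6.0, since dy/dx = 2x at x = 3

    One caveat worth knowing: a non-persistent tape releases its resources after the first call to tape.gradient(), so a second call raises an error. Pass persistent=True to tf.GradientTape() if you need several gradient computations from the same tape.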