tfx.components.FnArgs is how values are passed to the run_fn function, which in turn trains the model in a TensorFlow Extended (TFX) pipeline.
Looking at the documentation for tfx.components.FnArgs, I can't help but wonder why it has no attribute for the number of epochs to run the training loop (arguably the most important training parameter). Is this an oversight, or am I supposed to control the number of epochs differently?
You can pass the number of epochs through the custom_config dict, as shown in the example notebook. Trainer forwards that dict to run_fn, which can read it from fn_args.custom_config.
Example code:
```python
import os

from tfx import v1 as tfx

# context (an InteractiveContext), _trainer_module_file, ratings_transform and
# movies_transform are assumed to be defined earlier, as in the example notebook.
trainer = tfx.components.Trainer(
    module_file=os.path.abspath(_trainer_module_file),
    examples=ratings_transform.outputs['transformed_examples'],
    transform_graph=ratings_transform.outputs['transform_graph'],
    schema=ratings_transform.outputs['post_transform_schema'],
    train_args=tfx.proto.TrainArgs(num_steps=500),
    eval_args=tfx.proto.EvalArgs(num_steps=10),
    custom_config={
        # Available inside run_fn as fn_args.custom_config['epochs'].
        'epochs': 5,
        'movies': movies_transform.outputs['transformed_examples'],
        'movie_schema': movies_transform.outputs['post_transform_schema'],
        'ratings': ratings_transform.outputs['transformed_examples'],
        'ratings_schema': ratings_transform.outputs['post_transform_schema'],
    })
context.run(trainer, enable_cache=False)
```
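On the run_fn side, the value still has to be read and handed to Keras. Below is a minimal sketch, assuming run_fn trains with keras Model.fit. The helper name fit_kwargs_from_fn_args is hypothetical, and a SimpleNamespace stands in for the real FnArgs object so the snippet runs without TFX; it only relies on the train_steps, eval_steps and custom_config attributes that FnArgs exposes. Treating train_steps as steps per epoch is a design assumption, not something TFX prescribes.

```python
from types import SimpleNamespace
from typing import Any, Dict


def fit_kwargs_from_fn_args(fn_args: Any, default_epochs: int = 1) -> Dict[str, int]:
    """Build keyword arguments for keras Model.fit from a TFX FnArgs-like object.

    'epochs' is a user-chosen key passed through Trainer's custom_config;
    if it is missing, default_epochs is used.
    """
    custom_config = getattr(fn_args, "custom_config", None) or {}
    epochs = int(custom_config.get("epochs", default_epochs))
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    return {
        "epochs": epochs,
        # From TrainArgs(num_steps=...); used here as the steps per epoch.
        "steps_per_epoch": fn_args.train_steps,
        # From EvalArgs(num_steps=...).
        "validation_steps": fn_args.eval_steps,
    }


# Hypothetical stand-in for tfx.components.FnArgs, using the Trainer values above.
fake_fn_args = SimpleNamespace(train_steps=500, eval_steps=10, custom_config={"epochs": 5})
print(fit_kwargs_from_fn_args(fake_fn_args))
# prints {'epochs': 5, 'steps_per_epoch': 500, 'validation_steps': 10}
```

Inside a real run_fn this would be used as `model.fit(train_dataset, validation_data=eval_dataset, **fit_kwargs_from_fn_args(fn_args))`. With the Trainer above that is 5 × 500 = 2,500 training steps in total, assuming the training dataset repeats so that steps_per_epoch, not the end of the data, defines an epoch. If you would rather keep TrainArgs.num_steps as the total budget, divide train_steps by the epoch count instead.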