tensorflow tfx

Why doesn't tfx.components.FnArgs have an "epochs" attribute?


tfx.components.FnArgs is how values are passed to the run_fn function, which in turn trains the model in a TensorFlow Extended (TFX) pipeline.

Looking at the documentation for tfx.components.FnArgs, I can't help but wonder why there is no attribute for the number of epochs to run the training loop (arguably the most important training parameter). Is this an oversight, or am I supposed to control the number of epochs some other way?


Solution

  • You can pass the number of epochs through the Trainer's custom_config dict, as shown in the example notebook, and read it inside run_fn from fn_args.custom_config.

    Example code:

    import os

    from tfx import v1 as tfx

    # Assumes _trainer_module_file, movies_transform, ratings_transform and
    # context (an InteractiveContext) are defined earlier in the notebook.
    trainer = tfx.components.Trainer(
        module_file=os.path.abspath(_trainer_module_file),
        examples=ratings_transform.outputs['transformed_examples'],
        transform_graph=ratings_transform.outputs['transform_graph'],
        schema=ratings_transform.outputs['post_transform_schema'],
        train_args=tfx.proto.TrainArgs(num_steps=500),  # fn_args.train_steps
        eval_args=tfx.proto.EvalArgs(num_steps=10),     # fn_args.eval_steps
        custom_config={
            'epochs': 5,  # read in run_fn as fn_args.custom_config['epochs']
            'movies': movies_transform.outputs['transformed_examples'],
            'movie_schema': movies_transform.outputs['post_transform_schema'],
            'ratings': ratings_transform.outputs['transformed_examples'],
            'ratings_schema': ratings_transform.outputs['post_transform_schema'],
        })

    context.run(trainer, enable_cache=False)
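FnArgs does carry the training length, just in steps rather than epochs: TrainArgs(num_steps=...) and EvalArgs(num_steps=...) arrive in run_fn as fn_args.train_steps and fn_args.eval_steps. Below is a minimal sketch of reading the epoch count from fn_args.custom_config on the run_fn side, assuming a Trainer configured like the one above. The helper names get_epochs and steps_per_epoch are hypothetical, and splitting train_steps evenly across epochs is one possible design choice, not something TFX does for you. A SimpleNamespace stands in for FnArgs so the helpers can be run without TFX installed.

```python
from types import SimpleNamespace


def get_epochs(fn_args, default=1):
    """Hypothetical helper: read the epoch count passed via Trainer(custom_config=...).

    fn_args.custom_config holds the dict given to the Trainer; it may be
    None when no custom_config was set, so fall back to an empty dict.
    """
    custom_config = fn_args.custom_config or {}
    return int(custom_config.get('epochs', default))


def steps_per_epoch(fn_args, epochs):
    """Hypothetical helper: spread fn_args.train_steps evenly over the epochs.

    This split is an assumption (total steps stay equal to TrainArgs.num_steps);
    it is not TFX behaviour.
    """
    return max(1, fn_args.train_steps // epochs)


# Inside a real run_fn you might then call (assuming train_dataset repeats,
# otherwise Keras can run out of data before the last epoch):
#   epochs = get_epochs(fn_args)
#   model.fit(train_dataset,
#             epochs=epochs,
#             steps_per_epoch=steps_per_epoch(fn_args, epochs),
#             validation_data=eval_dataset,
#             validation_steps=fn_args.eval_steps)

# Stand-in for FnArgs, mirroring the Trainer configuration above.
fake_fn_args = SimpleNamespace(custom_config={'epochs': 5},
                               train_steps=500, eval_steps=10)
epochs = get_epochs(fake_fn_args)
print(epochs, steps_per_epoch(fake_fn_args, epochs))  # prints "5 100"
```

If you prefer each epoch to run the full num_steps instead, pass steps_per_epoch=fn_args.train_steps; total training then becomes epochs × num_steps.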