python, machine-learning, keras, tensorflow2.0, tfx

TensorFlow Extended | Trainer Not Warm Starting With GenericExecutor & Keras Model


I'm trying to get the Trainer component of a TFX pipeline to warm-start from a previous run of the same pipeline. The use case is:

  1. Run the pipeline once, produce a model.
  2. As new data comes in, train the existing model with the new data.

I understand the ResolverNode component is designed for this purpose; you can see how I use it below:

from tfx.components import ResolverNode
from tfx.dsl.experimental import latest_artifacts_resolver
from tfx.types import Channel
from tfx.types.standard_artifacts import Model

# detect the previously trained model
latest_model_resolver = ResolverNode(
    instance_name='latest_model_resolver',
    resolver_class=latest_artifacts_resolver.LatestArtifactsResolver,
    latest_model=Channel(type=Model))
context.run(latest_model_resolver)

import os

from tfx.components import Trainer
from tfx.components.base import executor_spec
from tfx.components.trainer.executor import GenericExecutor
from tfx.proto import trainer_pb2

# set the prior model as base_model
train_file = 'tfx_modules/recommender_train.py'
trainer = Trainer(
    module_file=os.path.abspath(train_file),
    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
    transformed_examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000),
    base_model=latest_model_resolver.outputs['latest_model'])

The components above run successfully, and the ResolverNode detects the latest model from prior pipeline runs. No error is thrown; however, when running context.run(trainer), the loss starts roughly where it did on the first run. The first run finishes with a training loss of ~0.1, but the second run (the supposed warm start) begins again at ~18.2.

This leads me to believe all the weights were re-initialized, which shouldn't happen when warm-starting. Below are the relevant model-construction functions:

import os
from pickle import load

from tensorflow import keras


def build_keras_model():
    """Build and compile the Keras recommender model."""
    with open(os.path.abspath('tfx-example/user_artifacts/embedding_max_dict.pkl'), 'rb') as f:
        embedding_max_values = load(f)
    embedding_dimensions = {key: 20 for key in embedding_max_values}
    embedding_pairs = [recommender.EmbeddingPair(embedding_name=feature,
                                                 embedding_dimension=embedding_dimensions[feature],
                                                 embedding_max_val=embedding_max_values[feature])
                       for feature in recommender_constants.univalent_features]

    numeric_inputs = [keras.Input(shape=(1,), name=num_feature)
                      for num_feature in recommender_constants.numeric_features]

    input_layers = numeric_inputs + [elem for pair in embedding_pairs for elem in pair.input_layers]
    pre_concat_layers = numeric_inputs + [elem for pair in embedding_pairs for elem in pair.embedding_layers]

    concat = keras.layers.Concatenate()(pre_concat_layers) if len(pre_concat_layers) > 1 else pre_concat_layers[0]
    layer_1 = keras.layers.Dense(64, activation='relu', name='layer1')(concat)
    output = keras.layers.Dense(1, kernel_initializer='lecun_uniform', name='out')(layer_1)
    model = keras.Model(inputs=input_layers, outputs=output)
    model.compile(optimizer='adam', loss='mean_squared_error')
    return model

import tensorflow as tf
import tensorflow_transform as tft
from tfx.components.trainer.executor import TrainerFnArgs


def run_fn(fn_args: TrainerFnArgs):
    """Entry point for the Trainer component."""
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    # batch size of 40 for both the train and eval datasets
    train_dataset = _input_fn(fn_args.train_files, fn_args.data_accessor,
                              tf_transform_output, 40)
    eval_dataset = _input_fn(fn_args.eval_files, fn_args.data_accessor,
                             tf_transform_output, 40)

    model = build_keras_model()
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=fn_args.model_run_dir, update_freq='epoch', histogram_freq=1,
        write_images=True)
    model.fit(train_dataset, steps_per_epoch=fn_args.train_steps,
              validation_data=eval_dataset, validation_steps=fn_args.eval_steps,
              callbacks=[tensorboard_callback], epochs=5)

    signatures = {
        'serving_default':
            _get_serve_tf_examples_fn(model, tf_transform_output).get_concrete_function(
                tf.TensorSpec(shape=[None], dtype=tf.string, name='examples'))
    }
    model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)
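
To check what the Trainer actually hands to run_fn, a log line can be added at the top of the function (a minimal sketch; absl logging is what TFX module files commonly use, and the log line is illustrative rather than part of the original module):

from absl import logging

def run_fn(fn_args: TrainerFnArgs):
    # fn_args.base_model holds the path to the resolved prior model,
    # or None when no base_model channel was wired into the Trainer.
    logging.info('base_model passed to run_fn: %s', fn_args.base_model)
    # ... rest of run_fn unchanged ...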

To research the problem, I have looked through the warm-start example from TFX:

https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_warmstart.py

However, that example uses the Estimator-based Trainer rather than a Keras model. tf.estimator.Estimator has a warm_start_from constructor parameter, for which I couldn't find a Keras equivalent.
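
For comparison, on the Estimator side warm starting is just a constructor argument; a minimal sketch (my_model_fn and both directory paths are placeholders, not taken from the linked example):

import tensorflow as tf

# Sketch only: `my_model_fn` and both directories are placeholders.
estimator = tf.estimator.Estimator(
    model_fn=my_model_fn,
    model_dir='new_model_dir',
    # Initialize variables from the prior run's checkpoint instead of from
    # scratch; accepts a checkpoint path or tf.estimator.WarmStartSettings.
    warm_start_from='previous_model_dir')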

I suspect one of two things:

  1. The warm-start functionality is only available for Estimator-based Trainers, and setting base_model has no effect for Keras models.

  2. I am somehow telling the model to re-initialize its weights even after the prior model is loaded successfully - in that case I would love a pointer to where that's happening.

Any assistance would be great! Many thanks.


Solution

  • With a Keras model you have to load the prior model yourself from the base-model path and continue training it, instead of building a new model; the GenericExecutor does not warm-start for you, it only passes the resolved model's path through to run_fn as fn_args.base_model.

    Your Trainer component looks correct, but in run_fn do the following instead of calling build_keras_model():

    def run_fn(fn_args: TrainerFnArgs):
        model = tf.keras.models.load_model(fn_args.base_model)
        model.fit(train_dataset, steps_per_epoch=fn_args.train_steps,
                  validation_data=eval_dataset, validation_steps=fn_args.eval_steps,
                  callbacks=[tensorboard_callback], epochs=5)
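
    Since fn_args.base_model is None on the very first pipeline run (there is no prior model to resolve yet), a fuller run_fn would fall back to building the model from scratch. A sketch along those lines, reusing the helpers from the module file above:

    def run_fn(fn_args: TrainerFnArgs):
        tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

        train_dataset = _input_fn(fn_args.train_files, fn_args.data_accessor,
                                  tf_transform_output, 40)
        eval_dataset = _input_fn(fn_args.eval_files, fn_args.data_accessor,
                                 tf_transform_output, 40)

        if fn_args.base_model is not None:
            # Warm start: continue training the model from the previous run.
            model = tf.keras.models.load_model(fn_args.base_model)
        else:
            # First run: no prior model was resolved, so build from scratch.
            model = build_keras_model()

        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=fn_args.model_run_dir, update_freq='epoch', histogram_freq=1,
            write_images=True)
        model.fit(train_dataset, steps_per_epoch=fn_args.train_steps,
                  validation_data=eval_dataset, validation_steps=fn_args.eval_steps,
                  callbacks=[tensorboard_callback], epochs=5)

        signatures = {
            'serving_default':
                _get_serve_tf_examples_fn(model, tf_transform_output).get_concrete_function(
                    tf.TensorSpec(shape=[None], dtype=tf.string, name='examples'))
        }
        model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)

    Because the model is saved in the TF SavedModel format, load_model also restores its compile state (including the optimizer), so the second run should resume training rather than restart it.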