I'm presently trying to get a Trainer component of a TFX pipeline to warm-start from a previous run of the same pipeline. The use case is:
I am aware the ResolverNode component is designed for this purpose, so here is how I use it:
# detect the previously trained model
latest_model_resolver = ResolverNode(
    instance_name='latest_model_resolver',
    resolver_class=latest_artifacts_resolver.LatestArtifactsResolver,
    latest_model=Channel(type=Model))
context.run(latest_model_resolver)

# set prior model as base_model
train_file = 'tfx_modules/recommender_train.py'
trainer = Trainer(
    module_file=os.path.abspath(train_file),
    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
    transformed_examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000),
    base_model=latest_model_resolver.outputs['latest_model'])
The components above run successfully, and the ResolverNode is able to detect the latest model from prior pipeline runs. No error is thrown. However, when running context.run(trainer), the model loss begins roughly where it started the first time. The first run finishes with a training loss of ~0.1, but the second run (with the supposed warm-start) restarts at ~18.2.
This leads me to believe all weights were re-initialized, which I don't believe should have occurred. Below are the relevant model construction functions:
def build_keras_model():
    """build keras model"""
    embedding_max_values = load(open(os.path.abspath('tfx-example/user_artifacts/embedding_max_dict.pkl'), 'rb'))
    embedding_dimensions = dict([(key, 20) for key in embedding_max_values.keys()])
    embedding_pairs = [recommender.EmbeddingPair(embedding_name=feature,
                                                 embedding_dimension=embedding_dimensions[feature],
                                                 embedding_max_val=embedding_max_values[feature])
                       for feature in recommender_constants.univalent_features]
    numeric_inputs = []
    for num_feature in recommender_constants.numeric_features:
        numeric_inputs.append(keras.Input(shape=(1,), name=num_feature))
    input_layers = numeric_inputs + [elem for pair in embedding_pairs for elem in pair.input_layers]
    pre_concat_layers = numeric_inputs + [elem for pair in embedding_pairs for elem in pair.embedding_layers]
    concat = keras.layers.Concatenate()(pre_concat_layers) if len(pre_concat_layers) > 1 else pre_concat_layers[0]
    layer_1 = keras.layers.Dense(64, activation='relu', name='layer1')(concat)
    output = keras.layers.Dense(1, kernel_initializer='lecun_uniform', name='out')(layer_1)
    model = keras.models.Model(input_layers, outputs=output)
    model.compile(optimizer='adam', loss='mean_squared_error')
    return model

def run_fn(fn_args: TrainerFnArgs):
    """function for the Trainer component"""
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
    train_dataset = _input_fn(fn_args.train_files, fn_args.data_accessor,
                              tf_transform_output, 40)
    eval_dataset = _input_fn(fn_args.eval_files, fn_args.data_accessor,
                             tf_transform_output, 40)
    model = build_keras_model()
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=fn_args.model_run_dir, update_freq='epoch', histogram_freq=1,
        write_images=True)
    model.fit(train_dataset, steps_per_epoch=fn_args.train_steps, validation_data=eval_dataset,
              validation_steps=fn_args.eval_steps, callbacks=[tensorboard_callback],
              epochs=5)
    signatures = {
        'serving_default':
            _get_serve_tf_examples_fn(model, tf_transform_output).get_concrete_function(
                tf.TensorSpec(
                    shape=[None],
                    dtype=tf.string,
                    name='examples')
            )
    }
    model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)
To research the problem, I have perused:
Warm Start Example From TFX https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_warmstart.py
However, this guide uses the Estimator component instead of the Keras components. That component has a warm_start_from initialization parameter which I couldn't find for the Keras equivalent.
I suspect:
Either the warm-start functionality is only available for Estimator components and won't take effect even if base_model is set for Keras components,
or I am somehow telling the model to re-initialize weights even after successfully loading the prior model. In that case I would love a pointer as to where that's happening.
Any assistance would be great! Many thanks.
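With Keras models, nothing loads the base model for you: run_fn has to load it from the base model path and continue training from there. In your run_fn, build_keras_model() creates a new model with freshly initialized weights on every run, which is why the loss starts over.
Your Trainer component looks correct, but in run_fn do the following instead: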
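def run_fn(fn_args: FnArgs):
    # Load the previously trained model instead of calling build_keras_model().
    model = tf.keras.models.load_model(fn_args.base_model)
    # train_dataset, eval_dataset and tensorboard_callback are created
    # exactly as in your original run_fn; the save step also stays the same.
    model.fit(train_dataset, steps_per_epoch=fn_args.train_steps, validation_data=eval_dataset,
              validation_steps=fn_args.eval_steps, callbacks=[tensorboard_callback],
              epochs=5)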
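
One thing to watch for: on the very first pipeline run there is no earlier model to load. A minimal sketch of a fallback follows, assuming fn_args.base_model is None or empty when no prior model was resolved (worth verifying for your TFX version). The helper name load_or_build is hypothetical and not part of TFX or Keras; the loader and builder are passed in as arguments so the helper itself has no dependencies.

```python
from typing import Any, Callable, Optional


def load_or_build(
    base_model_path: Optional[str],
    load_fn: Callable[[str], Any],
    build_fn: Callable[[], Any],
) -> Any:
    """Return a warm-started model if a base model path is given, else a fresh one."""
    if base_model_path:
        # Warm start: restore the previously trained model from disk.
        return load_fn(base_model_path)
    # Cold start: no prior model is available, so build one from scratch.
    return build_fn()


# Hypothetical use inside run_fn (assumes tf and build_keras_model are in scope):
# model = load_or_build(fn_args.base_model,
#                       tf.keras.models.load_model,
#                       build_keras_model)
```

With this in place the same run_fn should work for both the first run and later warm-started runs. A quick way to check that the warm start took effect is to call model.evaluate on the eval dataset before model.fit and compare the result with the final loss of the previous run.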