Tags: python, tensorflow, tensorflow-estimator

Create Estimator from checkpoint and save as SavedModel without further training


I have created an Estimator from a TF-Slim ResNet V2 checkpoint and tested it to make predictions. What I did is essentially a normal Estimator model_fn combined with slim.assign_from_checkpoint_fn:

def model_fn(features, labels, mode, params):
  ...
  slim.assign_from_checkpoint_fn(os.path.join(checkpoint_dir, 'resnet_v2_50.ckpt'), slim.get_model_variables('resnet_v2'))
  ...
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
      'class_ids': predicted_classes[:, tf.newaxis],
      'probabilities': tf.nn.softmax(logits),
      'logits': logits,
    }
  return tf.estimator.EstimatorSpec(mode, predictions=predictions)
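A side note on the snippet above: slim.assign_from_checkpoint_fn returns an initialization function that must be run against a session, so calling it inline and discarding the result does not restore anything. With an Estimator it is usually wired in through a tf.train.Scaffold. A minimal sketch of that wiring (my own illustration using the names from the question, replacing the elided middle of model_fn); note that even with this, the predict/export path still looks for a checkpoint in model_dir, which is exactly the error described below:

# Sketch only: these lines would sit inside model_fn, alongside the elided code above.
init_fn = slim.assign_from_checkpoint_fn(
    os.path.join(checkpoint_dir, 'resnet_v2_50.ckpt'),
    slim.get_model_variables('resnet_v2'))

# Scaffold.init_fn is called once with (scaffold, session) after the graph is built.
scaffold = tf.train.Scaffold(init_fn=lambda _, sess: init_fn(sess))

return tf.estimator.EstimatorSpec(mode, predictions=predictions, scaffold=scaffold)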

To export the estimator as a SavedModel, I made a serving_input_fn as follows:

def image_preprocess(image_buffer):
    image = tf.image.decode_jpeg(image_buffer, channels=3)
    image_preprocessing_fn = preprocessing_factory.get_preprocessing('inception', is_training=False)
    image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
    return image

def serving_input_fn():
    input_ph = tf.placeholder(tf.string, shape=[None], name='image_binary')
    image_tensors = image_preprocess(input_ph)
    return tf.estimator.export.ServingInputReceiver(image_tensors, input_ph)
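One detail worth flagging (my observation, not part of the original post): the placeholder holds a batch of serialized JPEG strings (shape [None]), while tf.image.decode_jpeg inside image_preprocess expects a single scalar string. A sketch of how the batch could be preprocessed element by element with tf.map_fn:

def serving_input_fn():
    input_ph = tf.placeholder(tf.string, shape=[None], name='image_binary')
    # Decode and preprocess each serialized image in the batch independently.
    image_tensors = tf.map_fn(image_preprocess, input_ph, dtype=tf.float32)
    return tf.estimator.export.ServingInputReceiver(image_tensors, input_ph)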

In the main function, I call export_saved_model to export the Estimator to the SavedModel format:

def main():
    ...
    classifier = tf.estimator.Estimator(model_fn=model_fn)
    classifier.export_saved_model(dir_path, serving_input_fn)

However, when I run the code, it fails with "Couldn't find trained model at /tmp/tmpn3spty2z". As I understand it, export_saved_model looks for a trained Estimator checkpoint in the model directory before exporting. Is there a way to restore the pretrained checkpoint into an Estimator and export it to a SavedModel without any further training?


Solution

  • I have solved my problem. To export a TF-Slim ResNet checkpoint to a SavedModel with TF 1.14 without any further training, a warm start can be combined with export_savedmodel as follows:

    config = tf.estimator.RunConfig(save_summary_steps=None, save_checkpoints_secs=None)
    # Warm-start all variables from the pretrained Slim checkpoint.
    warm_start = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=os.path.join(checkpoint_dir, checkpoint_name))
    classifier = tf.estimator.Estimator(model_fn=model_fn, warm_start_from=warm_start, config=config)
    classifier.export_savedmodel(export_dir_base=FLAGS.output_dir,
                                 serving_input_receiver_fn=serving_input_fn)
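
As a quick sanity check (my own addition, not part of the original answer), export_savedmodel returns the timestamped export directory, so the SavedModel can be loaded straight back with tf.contrib.predictor. The feed key and the test image below are assumptions on my part; `saved_model_cli show --dir <export_dir> --all` prints the actual signature input names.

    from tensorflow.contrib import predictor

    # export_savedmodel returns the export directory as a bytes path.
    export_dir = classifier.export_savedmodel(
        export_dir_base=FLAGS.output_dir,
        serving_input_receiver_fn=serving_input_fn)
    predict_fn = predictor.from_saved_model(export_dir.decode('utf-8'))

    with open('test.jpg', 'rb') as f:  # hypothetical test image
        image_bytes = f.read()

    # The key below assumes the default name ServingInputReceiver gives a bare
    # receiver tensor; check the printed signature if it differs.
    outputs = predict_fn({'input': [image_bytes]})
    print(outputs['class_ids'], outputs['probabilities'])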