Tags: tensorflow, keras, tfx, tensorflow-model-analysis, tensorflow-extended

TFX Tensorflow model validator component - You passed a data dictionary with keys ['image_raw_xf']. Expected the following keys: ['input_1']


I'm building a TFX pipeline based on the CIFAR10 example: https://github.com/tensorflow/tfx/tree/master/tfx/examples/cifar10

The difference is that I don't want to convert the model to TFLite; instead I want to use a regular Keras-based TensorFlow model.

Everything works as expected until the Evaluator component, which fails with the following error:

ValueError: Missing data for input "input_1". You passed a data dictionary with keys ['image_xf']. Expected the following keys: ['input_1']
 [while running 'Run[Trainer]']

Not sure what I'm doing wrong, but so far I debugged/modified the code as follows:

[1] preprocessing_fn outputs the image under the key image_xf:

_IMAGE_KEY = 'image'
_LABEL_KEY = 'label'

def _transformed_name(key):
  return key + '_xf'

def preprocessing_fn(inputs):
  """tf.transform's callback function for preprocessing inputs.

  Args:
    inputs: map from feature keys to raw not-yet-transformed features.

  Returns:
    Map from string feature key to transformed feature operations.
  """
  outputs = {}

  # tf.io.decode_png function cannot be applied on a batch of data.
  # We have to use tf.map_fn
  image_features = tf.map_fn(
      lambda x: tf.io.decode_png(x[0], channels=3),
      inputs[_IMAGE_KEY],
      dtype=tf.uint8)
  # image_features = tf.cast(image_features, tf.float32)
  image_features = tf.image.resize(image_features, [224, 224])
  image_features = tf.keras.applications.mobilenet.preprocess_input(
      image_features)

  outputs[_transformed_name(_IMAGE_KEY)] = image_features
  #outputs["input_1"] = image_features
  # TODO(b/157064428): Support label transformation for Keras.
  # Do not apply label transformation as it will result in wrong evaluation.
  outputs[_transformed_name(_LABEL_KEY)] = inputs[_LABEL_KEY]

  return outputs

[2] When I build the model, I use transfer learning with an InputLayer that has the same name, image_xf.

def _build_keras_model() -> tf.keras.Model:
  """Creates a Image classification model with MobileNet backbone.

  Returns:
    The image classification Keras Model and the backbone MobileNet model
  """
  # We create a MobileNet model with weights pre-trained on ImageNet.
  # We remove the top classification layer of the MobileNet, which was
  # used for classifying ImageNet objects. We will add our own classification
  # layer for CIFAR10 later. We use average pooling at the last convolution
  # layer to get a 1D vector for classification, which is consistent with the
  # original MobileNet setup.
  base_model = tf.keras.applications.MobileNet(
      input_shape=(224, 224, 3),
      include_top=False,
      weights='imagenet',
      pooling='avg')
  base_model.input_spec = None

  # We add a Dropout layer on top of the MobileNet backbone we just created to
  # prevent overfitting, and then a Dense layer to classify CIFAR10 objects.
  model = tf.keras.Sequential([
      tf.keras.layers.InputLayer(
          input_shape=(224, 224, 3), name=_transformed_name(_IMAGE_KEY)),
      base_model,
      tf.keras.layers.Dropout(0.1),
      tf.keras.layers.Dense(10, activation='softmax')
  ])

[3] The model signature is created accordingly:


def _get_serve_image_fn(model, tf_transform_output):
  """Returns a function that feeds the input tensor into the model."""
  model.tft_layer = tf_transform_output.transform_features_layer()
  @tf.function
  def serve_image_fn(serialized_tf_examples):
    
    feature_spec = tf_transform_output.raw_feature_spec()
    feature_spec.pop(_LABEL_KEY)
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)

    transformed_features = model.tft_layer(parsed_features)
    
    return model(transformed_features)

  return serve_image_fn

def run_fn(fn_args: FnArgs):

  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)


  signatures = {
      'serving_default':
          _get_serve_image_fn(model,tf_transform_output).get_concrete_function(
              tf.TensorSpec(
                  shape=[None],
                  dtype=tf.string,
                  name=_IMAGE_KEY))
  }

  temp_saving_model_dir = os.path.join(fn_args.serving_model_dir)
  model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

Now, I suspect that TensorFlow is not saving the model correctly, because when I load the exported SavedModel, the input layer is input_1 instead of image_xf.

import tensorflow as tf
import numpy as np

import tensorflow.python.ops.numpy_ops.np_config as np_config
np_config.enable_numpy_behavior()

path = './model/Format-Serving/'

imported = tf.saved_model.load(path)

model = tf.keras.models.load_model(path)
print(model.summary())

print(list(imported.signatures.keys()))

print(model.get_layer('mobilenet_1.00_224').layers[0].name)

Two things to notice here: (1) the InputLayer I added to the Sequential model above is missing, and (2) the first layer of the MobileNet backbone is named input_1, which explains the mismatch.

2021-10-15 08:33:40.683034: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenet_1.00_224 (Function (None, 1024)              3228864   
_________________________________________________________________
dropout (Dropout)            (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 10)                10250     
=================================================================
Total params: 3,239,114
Trainable params: 1,074,186
Non-trainable params: 2,164,928
_________________________________________________________________
None
['serving_default']
input_1
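
The mismatch can be reproduced without TensorFlow. The sketch below is only a simplified stand-in for a by-name input check of the kind that produces the error above; check_input_keys is a hypothetical helper written for illustration, not a Keras or TFX function.

```python
def check_input_keys(expected_names, data):
    """Simplified stand-in for a by-name input check (hypothetical helper)."""
    for name in expected_names:
        if name not in data:
            raise ValueError(
                'Missing data for input "%s". You passed a data dictionary '
                'with keys %s. Expected the following keys: %s'
                % (name, list(data.keys()), list(expected_names)))
    # Return the values in the order the model expects them.
    return [data[name] for name in expected_names]


# The reloaded model reports 'input_1', but Transform emits 'image_xf'.
features = {'image_xf': 'transformed image batch'}
try:
    check_input_keys(['input_1'], features)
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)
```

With the key the model actually expects, the same check passes and returns the values in input order, which suggests the fix has to happen wherever the feature dictionary is handed to the model.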

So how can I actually get the model to save correctly with the right input?

Here is the full code:

pipeline.py

# Lint as: python2, python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CIFAR10 image classification example using TFX.

This example demonstrates how to do data augmentation, transfer learning,
and inserting TFLite metadata with TFX.
The trained model can be plugged into MLKit for object detection.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import os
from typing import List, Text

import absl
from tfx import v1 as tfx
import tensorflow_model_analysis as tfma
from tfx.components import Evaluator
from tfx.components import ExampleValidator
from tfx.components import ImportExampleGen
from tfx.components import Pusher
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.components import Trainer
from tfx.components import Transform
from tfx.dsl.components.common import resolver
from tfx.dsl.experimental import latest_blessed_model_resolver
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
from tfx.proto import example_gen_pb2
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.types import Channel
from tfx.types.standard_artifacts import Model
from tfx.types.standard_artifacts import ModelBlessing

_pipeline_name = 'cifar10_native_keras'

# This example assumes that CIFAR10 train set data is stored in
# ~/cifar10/data/train, test set data is stored in ~/cifar10/data/test, and
# the utility function is in ~/cifar10. Feel free to customize as needed.
_cifar10_root = os.path.join(os.getcwd())
_data_root = os.path.join(_cifar10_root, 'data')
# Python module files to inject customized logic into the TFX components. The
# Transform and Trainer both require user-defined functions to run successfully.
_module_file = os.path.join(_cifar10_root, 'cifar10_utils_native_keras.py')
# Path which can be listened to by the model server. Pusher will output the
# trained model here.
_serving_model_dir_lite = os.path.join(_cifar10_root, 'serving_model_lite',
                                       _pipeline_name)

# Directory and data locations.  This example assumes all of the images,
# example code, and metadata library is relative to $HOME, but you can store
# these files anywhere on your local filesystem.
_tfx_root = os.path.join(os.getcwd(), 'tfx')
_pipeline_root = os.path.join(_tfx_root, 'pipelines', _pipeline_name)
# Sqlite ML-metadata db path.
_metadata_path = os.path.join(_tfx_root, 'metadata', _pipeline_name,
                              'metadata.db')
# Path to labels file for mapping model outputs.
_labels_path = os.path.join(_data_root, 'labels.txt')


# Pipeline arguments for Beam powered Components.
_beam_pipeline_args = [
    '--direct_running_mode=multi_processing',
    '--direct_num_workers=0', 
]


def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir_lite: Text,
                     metadata_path: Text,
                     labels_path: Text,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
  """Implements the CIFAR10 image classification pipeline using TFX."""
  # This is needed for datasets with pre-defined splits
  # Change the pattern argument to train_whole/* and test_whole/* to train
  # on the whole CIFAR-10 dataset
  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='train', pattern='train/*'),
      example_gen_pb2.Input.Split(name='eval', pattern='test/*')
  ])

  # Brings data into the pipeline.
  example_gen = ImportExampleGen(
      input_base=data_root, input_config=input_config)

  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

  # Generates schema based on statistics files.
  schema_gen = SchemaGen(
      statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)

  # Performs anomaly detection based on statistics and data schema.
  example_validator = ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=schema_gen.outputs['schema'])

  # Performs transformations and feature engineering in training and serving.
  transform = Transform(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file=module_file)

  model_resolver = resolver.Resolver(
    #instance_name='latest_model_resolver',
    strategy_class=tfx.dsl.experimental.LatestArtifactStrategy,
    model=Channel(type=Model)).with_id('latest_blessed_model_resolver')

  # Uses user-provided Python function that trains a model.
  # When training on the whole dataset, use 18744 for train steps, 156 for eval
  # steps. 18744 train steps correspond to 24 epochs on the whole train set, and
  # 156 eval steps correspond to 1 epoch on the whole test set. The
  # configuration below is for training on the dataset we provided in the data
  # folder, which has 128 train and 128 test samples. The 160 train steps
  # correspond to 40 epochs on this tiny train set, and 4 eval steps correspond
  # to 1 epoch on this tiny test set.
  trainer = Trainer(
      module_file=module_file,
      examples=transform.outputs['transformed_examples'],
      transform_graph=transform.outputs['transform_graph'],
      schema=schema_gen.outputs['schema'],
      base_model=model_resolver.outputs['model'],
      train_args=trainer_pb2.TrainArgs(num_steps=160),
      eval_args=trainer_pb2.EvalArgs(num_steps=4),
      custom_config={'labels_path': labels_path})

  # Get the latest blessed model for model validation.
  # model_resolver = resolver.Resolver(
  #     strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
  #     model=Channel(type=Model),
  #     model_blessing=Channel(
  #         type=ModelBlessing)).with_id('latest_blessed_model_resolver')
 
  # Uses TFMA to compute evaluation statistics over features of a model and
  # perform quality validation of a candidate model (compare to a baseline).
  eval_config = tfma.EvalConfig(
      model_specs=[tfma.ModelSpec(label_key='label')],
      slicing_specs=[tfma.SlicingSpec()],
      metrics_specs=[
          tfma.MetricsSpec(metrics=[
              tfma.MetricConfig(
                  class_name='SparseCategoricalAccuracy',
                  threshold=tfma.MetricThreshold(
                      value_threshold=tfma.GenericValueThreshold(
                          lower_bound={'value': 0.55}),
                      # Change threshold will be ignored if there is no
                      # baseline model resolved from MLMD (first run).
                      change_threshold=tfma.GenericChangeThreshold(
                          direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                          absolute={'value': -1e-3})))
          ])
      ])

  # Uses TFMA to compute the evaluation statistics over features of a model.
  # The Evaluator reads the raw examples from ExampleGen, so the model's
  # 'serving_default' signature has to parse them and apply the Transform
  # graph itself (see _get_serve_image_fn in the utils file).
  evaluator = Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer.outputs['model'],
      #baseline_model=model_resolver.outputs['model'],
      eval_config=eval_config)

  # Checks whether the model passed the validation steps and pushes the model
  # to a file destination if check passed.
  pusher = Pusher(
      model=trainer.outputs['model'],
      model_blessing=evaluator.outputs['blessing'],
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=serving_model_dir_lite)))

  components = [
      example_gen, statistics_gen, schema_gen, example_validator, transform,
      trainer, model_resolver, evaluator, pusher
  ]

  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=components,
      enable_cache=True,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      beam_pipeline_args=beam_pipeline_args)


# To run this pipeline from the python CLI:
#   $python cifar_pipeline_native_keras.py
if __name__ == '__main__':

  loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
  for logger in loggers:
    logger.setLevel(logging.INFO)
  logging.getLogger().setLevel(logging.INFO)

  absl.logging.set_verbosity(absl.logging.FATAL)
  BeamDagRunner().run(
      _create_pipeline(
          pipeline_name=_pipeline_name,
          pipeline_root=_pipeline_root,
          data_root=_data_root,
          module_file=_module_file,
          serving_model_dir_lite=_serving_model_dir_lite,
          metadata_path=_metadata_path,
          labels_path=_labels_path,
          beam_pipeline_args=_beam_pipeline_args))

utils file:

# Lint as: python2, python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python source file includes CIFAR10 utils for Keras model.

The utilities in this file are used to build a model with native Keras.
This module file will be used in Transform and generic Trainer.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from typing import List, Text
import absl
import tensorflow as tf
import tensorflow_transform as tft

from tfx.components.trainer.fn_args_utils import DataAccessor
from tfx.components.trainer.fn_args_utils import FnArgs
from tfx.components.trainer.rewriting import converters
from tfx.components.trainer.rewriting import rewriter
from tfx.components.trainer.rewriting import rewriter_factory
from tfx.dsl.io import fileio
from tfx_bsl.tfxio import dataset_options

# import flatbuffers
# from tflite_support import metadata_schema_py_generated as _metadata_fb
# from tflite_support import metadata as _metadata

# When training on the whole dataset use following constants instead.
# This setting should give ~91% accuracy on the whole test set
# _TRAIN_DATA_SIZE = 50000
# _EVAL_DATA_SIZE = 10000
# _TRAIN_BATCH_SIZE = 64
# _EVAL_BATCH_SIZE = 64
# _CLASSIFIER_LEARNING_RATE = 3e-4
# _FINETUNE_LEARNING_RATE = 5e-5
# _CLASSIFIER_EPOCHS = 12

_TRAIN_DATA_SIZE = 128
_EVAL_DATA_SIZE = 128
_TRAIN_BATCH_SIZE = 32
_EVAL_BATCH_SIZE = 32
_CLASSIFIER_LEARNING_RATE = 1e-3
_FINETUNE_LEARNING_RATE = 7e-6
_CLASSIFIER_EPOCHS = 30

_IMAGE_KEY = 'image'
_LABEL_KEY = 'label'

_TFLITE_MODEL_NAME = 'tflite'


def _transformed_name(key):
  return key + '_xf'


def _get_serve_image_fn(model, tf_transform_output):
  """Returns a function that feeds the input tensor into the model."""
  model.tft_layer = tf_transform_output.transform_features_layer()
  @tf.function
  def serve_image_fn(serialized_tf_examples):
    
    feature_spec = tf_transform_output.raw_feature_spec()
    feature_spec.pop(_LABEL_KEY)
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)

    transformed_features = model.tft_layer(parsed_features)
    
    return model(transformed_features)

  return serve_image_fn


def _image_augmentation(image_features):
  """Perform image augmentation on batches of images .

  Args:
    image_features: a batch of image features

  Returns:
    The augmented image features
  """
  batch_size = tf.shape(image_features)[0]
  image_features = tf.image.random_flip_left_right(image_features)
  image_features = tf.image.resize_with_crop_or_pad(image_features, 250, 250)
  image_features = tf.image.random_crop(image_features,
                                        (batch_size, 224, 224, 3))
  return image_features


def _data_augmentation(feature_dict):
  """Perform data augmentation on batches of data.

  Args:
    feature_dict: a dict containing features of samples

  Returns:
    The feature dict with augmented features
  """
  image_features = feature_dict[_transformed_name(_IMAGE_KEY)]
  image_features = _image_augmentation(image_features)
  feature_dict[_transformed_name(_IMAGE_KEY)] = image_features
  return feature_dict


def _input_fn(file_pattern: List[Text],
              data_accessor: DataAccessor,
              tf_transform_output: tft.TFTransformOutput,
              is_train: bool = False,
              batch_size: int = 200) -> tf.data.Dataset:
  """Generates features and label for tuning/training.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    tf_transform_output: A TFTransformOutput.
    is_train: Whether the input dataset is train split or not.
    batch_size: representing the number of consecutive elements of returned
      dataset to combine in a single batch

  Returns:
    A dataset that contains (features, indices) tuple where features is a
      dictionary of Tensors, and indices is a single Tensor of label indices.
  """
  dataset = data_accessor.tf_dataset_factory(
      file_pattern,
      dataset_options.TensorFlowDatasetOptions(
          batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY)),
      tf_transform_output.transformed_metadata.schema)
  # Apply data augmentation. We have to do data augmentation here because
  # we need to apply data augmentation on-the-fly during training. If we put
  # it in Transform, it will only be applied once on the whole dataset, which
  # defeats the purpose of data augmentation.
  if is_train:
    dataset = dataset.map(lambda x, y: (_data_augmentation(x), y))

  return dataset


def _freeze_model_by_percentage(model: tf.keras.Model, percentage: float):
  """Freeze part of the model based on specified percentage.

  Args:
    model: The keras model need to be partially frozen
    percentage: the percentage of layers to freeze

  Raises:
    ValueError: Invalid values.
  """
  if percentage < 0 or percentage > 1:
    raise ValueError('Freeze percentage should between 0.0 and 1.0')

  if not model.trainable:
    raise ValueError(
        'The model is not trainable, please set model.trainable to True')

  num_layers = len(model.layers)
  num_layers_to_freeze = int(num_layers * percentage)
  for idx, layer in enumerate(model.layers):
    if idx < num_layers_to_freeze:
      layer.trainable = False
    else:
      layer.trainable = True


def _build_keras_model() -> tf.keras.Model:
  """Creates a Image classification model with MobileNet backbone.

  Returns:
    The image classification Keras Model and the backbone MobileNet model
  """
  # We create a MobileNet model with weights pre-trained on ImageNet.
  # We remove the top classification layer of the MobileNet, which was
  # used for classifying ImageNet objects. We will add our own classification
  # layer for CIFAR10 later. We use average pooling at the last convolution
  # layer to get a 1D vector for classification, which is consistent with the
  # original MobileNet setup.
  base_model = tf.keras.applications.MobileNet(
      input_shape=(224, 224, 3),
      include_top=False,
      weights='imagenet',
      pooling='avg')
  base_model.input_spec = None

  # We add a Dropout layer on top of the MobileNet backbone we just created to
  # prevent overfitting, and then a Dense layer to classify CIFAR10 objects.
  model = tf.keras.Sequential([
      tf.keras.layers.InputLayer(
          input_shape=(224, 224, 3), name=_transformed_name(_IMAGE_KEY)),
      base_model,
      tf.keras.layers.Dropout(0.1),
      tf.keras.layers.Dense(10, activation='softmax')
  ])

  # Freeze the whole MobileNet backbone to first train the top classifier only
  _freeze_model_by_percentage(base_model, 1.0)

  model.compile(
      loss='sparse_categorical_crossentropy',
      optimizer=tf.keras.optimizers.RMSprop(
          learning_rate=_CLASSIFIER_LEARNING_RATE),
      metrics=['sparse_categorical_accuracy'])
  model.summary(print_fn=absl.logging.info)

  return model, base_model


# TFX Transform will call this function.
def preprocessing_fn(inputs):
  """tf.transform's callback function for preprocessing inputs.

  Args:
    inputs: map from feature keys to raw not-yet-transformed features.

  Returns:
    Map from string feature key to transformed feature operations.
  """
  outputs = {}

  # tf.io.decode_png function cannot be applied on a batch of data.
  # We have to use tf.map_fn
  image_features = tf.map_fn(
      lambda x: tf.io.decode_png(x[0], channels=3),
      inputs[_IMAGE_KEY],
      dtype=tf.uint8)
  # image_features = tf.cast(image_features, tf.float32)
  image_features = tf.image.resize(image_features, [224, 224])
  image_features = tf.keras.applications.mobilenet.preprocess_input(
      image_features)

  outputs[_transformed_name(_IMAGE_KEY)] = image_features
  #outputs["input_1"] = image_features
  # TODO(b/157064428): Support label transformation for Keras.
  # Do not apply label transformation as it will result in wrong evaluation.
  outputs[_transformed_name(_LABEL_KEY)] = inputs[_LABEL_KEY]

  return outputs

# TFX Trainer will call this function.
def run_fn(fn_args: FnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.

  Raises:
    ValueError: if invalid inputs.
  """
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

  baseline_path = fn_args.base_model
  if baseline_path is not None:
    model = tf.keras.models.load_model(os.path.join(baseline_path))
  else:

    train_dataset = _input_fn(
        fn_args.train_files,
        fn_args.data_accessor,
        tf_transform_output,
        is_train=True,
        batch_size=_TRAIN_BATCH_SIZE)
    eval_dataset = _input_fn(
        fn_args.eval_files,
        fn_args.data_accessor,
        tf_transform_output,
        is_train=False,
        batch_size=_EVAL_BATCH_SIZE)

    model, base_model = _build_keras_model()

    absl.logging.info('Tensorboard logging to {}'.format(fn_args.model_run_dir))
    # Write logs to path
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=fn_args.model_run_dir, update_freq='batch')

    # Our training regime has two phases: we first freeze the backbone and train
    # the newly added classifier only, then unfreeze part of the backbone and
    # fine-tune with classifier jointly.
    steps_per_epoch = int(_TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE)
    total_epochs = int(fn_args.train_steps / steps_per_epoch)
    if _CLASSIFIER_EPOCHS > total_epochs:
      raise ValueError('Classifier epochs is greater than the total epochs')

    absl.logging.info('Start training the top classifier')
    model.fit(
        train_dataset,
        epochs=_CLASSIFIER_EPOCHS,
        steps_per_epoch=steps_per_epoch,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        callbacks=[tensorboard_callback])

    absl.logging.info('Start fine-tuning the model')
    # Unfreeze the top MobileNet layers and do joint fine-tuning
    _freeze_model_by_percentage(base_model, 0.9)

    # We need to recompile the model because layer properties have changed
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=tf.keras.optimizers.RMSprop(
            learning_rate=_FINETUNE_LEARNING_RATE),
        metrics=['sparse_categorical_accuracy'])
    model.summary(print_fn=absl.logging.info)

    model.fit(
        train_dataset,
        initial_epoch=_CLASSIFIER_EPOCHS,
        epochs=total_epochs,
        steps_per_epoch=steps_per_epoch,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        callbacks=[tensorboard_callback])

  # Export the Keras model as a SavedModel with a serving signature that
  # accepts serialized tf.Example protos.
  signatures = {
      'serving_default':
          _get_serve_image_fn(model,tf_transform_output).get_concrete_function(
              tf.TensorSpec(
                  shape=[None],
                  dtype=tf.string,
                  name=_IMAGE_KEY))
  }

  temp_saving_model_dir = os.path.join(fn_args.serving_model_dir)
  model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

  # tfrw = rewriter_factory.create_rewriter(
  #     rewriter_factory.TFLITE_REWRITER,
  #     name='tflite_rewriter')
  # converters.rewrite_saved_model(temp_saving_model_dir,
  #                                fn_args.serving_model_dir, tfrw,
  #                                rewriter.ModelType.TFLITE_MODEL)

  # # Add necessary TFLite metadata to the model in order to use it within MLKit
  # # TODO(dzats@): Handle label map file path more properly, currently
  # # hard-coded.
  # tflite_model_path = os.path.join(fn_args.serving_model_dir,
  #                                  _TFLITE_MODEL_NAME)
  # # TODO(dzats@): Extend the TFLite rewriter to be able to add TFLite metadata
  # #@ to the model.
  # _write_metadata(
  #     model_path=tflite_model_path,
  #     label_map_path=fn_args.custom_config['labels_path'],
  #     mean=[127.5],
  #     std=[127.5])

  # fileio.rmtree(temp_saving_model_dir)

Solution

  • OK, I found the answer. The model expects an input named input_1, so in _get_serve_image_fn I need to re-key the transformed image feature under that name before calling the model:

    def _get_serve_image_fn(model, tf_transform_output):
      """Returns a function that feeds the input tensor into the model."""
      model.tft_layer = tf_transform_output.transform_features_layer()
      @tf.function
      def serve_image_fn(serialized_tf_examples):
        
        feature_spec = tf_transform_output.raw_feature_spec()
        feature_spec.pop(_LABEL_KEY)
        parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
    
        transformed_features = model.tft_layer(parsed_features)
        # Re-key the transformed image under the input name the model expects.
        input_name = model.get_layer('mobilenet_1.00_224').layers[0].name
        transformed_features[input_name] = transformed_features.pop(
            _transformed_name(_IMAGE_KEY))
        
        return model(transformed_features)
    
      return serve_image_fn
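
    The renaming step on its own can be sketched on plain dictionaries, independent of TensorFlow. This is a minimal illustration under assumptions: rename_feature_key is a hypothetical helper, and the keys 'image_xf', 'label_xf', and 'input_1' simply mirror the names discussed in this post.

```python
def rename_feature_key(features, old_key, new_key):
    """Returns a shallow copy of `features` with `old_key` renamed to `new_key`."""
    if old_key not in features:
        raise KeyError(
            '%r not found; available keys: %s' % (old_key, sorted(features)))
    renamed = dict(features)  # Copy so the caller's dictionary is untouched.
    renamed[new_key] = renamed.pop(old_key)
    return renamed


# Placeholder values stand in for the transformed tensors.
transformed = {'image_xf': [[0.1, 0.2]], 'label_xf': [3]}
model_inputs = rename_feature_key(transformed, 'image_xf', 'input_1')
print(sorted(model_inputs))
```

    Working on a copy, rather than deleting the key in place, keeps the original transformed features available to any other code that still refers to them by the _xf name.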