Tags: tensorflow, keras, tensorflow2.0, tf.keras

How to save and reload a Subclassed model in TF 2.6.0 / Python 3.9.7 without a performance drop?


This looks like the million-dollar question. I have the model below, built by subclassing Model in Keras.

The model trains fine and has good performance, but I cannot find a way to save and restore it without incurring a significant performance loss. I track AUC on ROC curves for anomaly detection, and the ROC curve after loading the model is worse than before, on exactly the same validation data set.

I suspect the problem comes from the BatchNormalization layers, but I could be wrong.
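One way to test that suspicion is to diff every weight, including the BatchNormalization moving_mean / moving_variance, after a round trip. A minimal sketch, assuming model is the trained instance and reloaded is the restored one:

import numpy as np

# Hypothetical diagnostic: any mismatch below points at what the
# save path lost (e.g. BatchNormalization moving statistics).
for w_old, w_new in zip(model.weights, reloaded.weights):
    if not np.allclose(w_old.numpy(), w_new.numpy()):
        print("Mismatch in", w_old.name)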

I've tried several options:

This works but leads to a performance drop:

model.save() / tf.keras.models.load_model()
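Concretely, a minimal sketch of that round trip (the directory name is arbitrary):

# Hypothetical round trip through the TF SavedModel format.
model.save("deep_seq2seq_detector")
reloaded = tf.keras.models.load_model("deep_seq2seq_detector")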

This also works but leads to a performance drop:

model.save_weights() / model.load_weights()
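A fresh subclassed model has no variables until it is called once, so the weights-only route needs the model re-instantiated and called on a dummy batch before restoring. A minimal sketch, assuming the same constructor arguments as in the class below:

# Hypothetical weights-only round trip.
model.save_weights("detector_ckpt")

restored = Deep_Seq2Seq_Detector(flight_len, param_len, hidden_state=16)
restored(tf.zeros((1, flight_len, param_len)))  # build the variables first
restored.load_weights("detector_ckpt")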

This does not work and I get the following error:

tf.saved_model.save() / tf.saved_model.load()

AttributeError: '_UserObject' object has no attribute 'predict'
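That error is expected: tf.saved_model.load returns a bare restored object without the Keras API, so inference has to go through a saved signature instead of .predict(). A sketch ("serving_default" is the usual signature key; the exact input handling depends on how the model was traced):

# Hypothetical use of the low-level loader: the _UserObject exposes
# signatures / concrete functions, not Keras methods like .predict().
loaded = tf.saved_model.load("deep_seq2seq_detector")
infer = loaded.signatures["serving_default"]
outputs = infer(tf.constant(x_batch, dtype=tf.float32))  # x_batch assumed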

This does not work either, as subclassed models do not support JSON export:

model.to_json()
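Relatedly, a subclassed model can expose its constructor arguments via get_config() so Keras can re-instantiate the Python class on load, rather than a bare restored object. A hedged sketch of methods one could add to the class below (not from the original code):

# Hypothetical additions to Deep_Seq2Seq_Detector.
def get_config(self):
    return {"flight_len": self.input_dim[1],
            "param_len": self.input_dim[2],
            "hidden_state": self.units}

@classmethod
def from_config(cls, config):
    return cls(**config)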

Here is the model:

import tensorflow as tf
from tensorflow.keras import Model, layers


class Deep_Seq2Seq_Detector(Model):
  def __init__(self, flight_len, param_len, hidden_state=16):
    super().__init__(name="LSTM")  # set the Keras model name directly
    self.input_dim = (None, flight_len, param_len)
    self.units = hidden_state
    
    # Input normalization
    self.regularizer0 = tf.keras.Sequential([
        layers.BatchNormalization()
        ])
    
    self.encoder1 = layers.LSTM(self.units,
                                return_state=False,
                                return_sequences=True,
                                name='encoder1',
                                input_shape=self.input_dim)
    
    self.regularizer1 = tf.keras.Sequential([
        layers.BatchNormalization(),
        layers.Activation("tanh")
        ])
    
    self.encoder2 = layers.LSTM(self.units,
                                return_state=False,
                                return_sequences=True,
                                name='encoder2')
    
    self.regularizer2 = tf.keras.Sequential([
        layers.BatchNormalization(),
        layers.Activation("tanh")
        ])
    
    self.encoder3 = layers.LSTM(self.units,
                                return_state=True,
                                return_sequences=False,
                                activation="tanh",
                                name='encoder3')
    
    self.repeat = layers.RepeatVector(self.input_dim[1])
    
    self.decoder = layers.LSTM(self.units,
                  return_sequences=True,
                  activation="tanh",
                  name="decoder",
                  input_shape=(self.input_dim[1],self.units))
    
    self.dense = layers.TimeDistributed(layers.Dense(self.input_dim[2]))

  @tf.function 
  def call(self, x):
    
    # Encoder
    x0 = self.regularizer0(x)
    x1 = self.encoder1(x0)
    x11 = self.regularizer1(x1)
    
    x2 = self.encoder2(x11)
    x22 = self.regularizer2(x2)
    
    output, hs, cs = self.encoder3(x22)
    
    # see https://www.tensorflow.org/guide/keras/rnn 
    encoded_state = [hs, cs] 
    repeated_vec = self.repeat(output)
    
    # Decoder
    decoded = self.decoder(repeated_vec, initial_state=encoded_state)
    output_decoder = self.dense(decoded)

    return output_decoder

I've seen GitHub threads, but no straight answer: https://github.com/keras-team/keras/issues/4875

Has anyone found a solution? Do I have to use the Functional or Sequential API instead?


Solution

  • It seems the problem was coming from the Subclassing API.

    I reconstructed the exact same model using the Functional API, and now model.save() / tf.keras.models.load_model() yields consistent results, as sketched below.
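A minimal sketch of that Functional reconstruction (layer order and hyperparameters taken from the subclassed version above, not verbatim code):

import tensorflow as tf
from tensorflow.keras import Model, layers

def build_detector(flight_len, param_len, hidden_state=16):
    inputs = tf.keras.Input(shape=(flight_len, param_len))
    x = layers.BatchNormalization()(inputs)
    x = layers.LSTM(hidden_state, return_sequences=True, name='encoder1')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("tanh")(x)
    x = layers.LSTM(hidden_state, return_sequences=True, name='encoder2')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("tanh")(x)
    output, hs, cs = layers.LSTM(hidden_state, return_state=True,
                                 activation="tanh", name='encoder3')(x)
    x = layers.RepeatVector(flight_len)(output)
    x = layers.LSTM(hidden_state, return_sequences=True, activation="tanh",
                    name='decoder')(x, initial_state=[hs, cs])
    outputs = layers.TimeDistributed(layers.Dense(param_len))(x)
    return Model(inputs, outputs, name="LSTM")

Because the Functional graph serializes its full config, model.save() / tf.keras.models.load_model() round-trips it without any custom-class machinery.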