This looks like the million-dollar question. I have the model below, built by subclassing Model in Keras.
The model trains fine and performs well, but I cannot find a way to save and restore it without incurring a significant performance loss. I track the AUC of ROC curves for anomaly detection, and the ROC curve after loading the model is worse than before, on exactly the same validation data set.
I suspect the problem comes from the BatchNormalization layers, but I could be wrong.
I've tried several options:

1. model.save() / tf.keras.models.load_model(): works, but leads to a performance drop.
2. model.save_weights() / model.load_weights(): also works, but leads to a performance drop.
3. tf.saved_model.save() / tf.saved_model.load(): does not work; I get the following error:
   AttributeError: '_UserObject' object has no attribute 'predict'
4. model.to_json(): does not work either, as subclassed models do not support JSON export.
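For reference, here is a minimal sketch of the round trip I am measuring. X_val and y_val stand in for my validation windows and their anomaly labels; they are placeholders, not defined anywhere above:

import numpy as np
import tensorflow as tf
from sklearn.metrics import roc_auc_score

# Per-sample reconstruction error, used as the anomaly score
def anomaly_scores(m, data):
    return np.mean(np.square(m.predict(data) - data), axis=(1, 2))

auc_before = roc_auc_score(y_val, anomaly_scores(model, X_val))

model.save("saved_detector")                            # TF SavedModel format
restored = tf.keras.models.load_model("saved_detector")

auc_after = roc_auc_score(y_val, anomaly_scores(restored, X_val))
print(auc_before, auc_after)                            # auc_after comes out worse

# Sanity check: do all weights, including the BatchNormalization
# moving mean/variance, survive the round trip?
for w_old, w_new in zip(model.get_weights(), restored.get_weights()):
    print(np.allclose(w_old, w_new))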
Here is the model:
import tensorflow as tf
from tensorflow.keras import Model, layers

class Deep_Seq2Seq_Detector(Model):
    """Seq2seq LSTM autoencoder used for anomaly detection."""

    def __init__(self, flight_len, param_len, hidden_state=16):
        super(Deep_Seq2Seq_Detector, self).__init__(name="LSTM")
        self.input_dim = (None, flight_len, param_len)
        self.units = hidden_state
        # Normalize the raw inputs
        self.regularizer0 = tf.keras.Sequential([
            layers.BatchNormalization()
        ])
        self.encoder1 = layers.LSTM(self.units,
                                    return_state=False,
                                    return_sequences=True,
                                    # activation="tanh",
                                    # kernel_regularizer=tf.keras.regularizers.l1(),
                                    name='encoder1',
                                    input_shape=self.input_dim)
        self.regularizer1 = tf.keras.Sequential([
            layers.BatchNormalization(),
            layers.Activation("tanh")
        ])
        self.encoder2 = layers.LSTM(self.units,
                                    return_state=False,
                                    return_sequences=True,
                                    # activation="tanh",
                                    # kernel_regularizer=tf.keras.regularizers.l1(),
                                    name='encoder2')
        self.regularizer2 = tf.keras.Sequential([
            layers.BatchNormalization(),
            layers.Activation("tanh")
        ])
        self.encoder3 = layers.LSTM(self.units,
                                    return_state=True,
                                    return_sequences=False,
                                    activation="tanh",
                                    # kernel_regularizer=tf.keras.regularizers.l1(),
                                    name='encoder3')
        self.repeat = layers.RepeatVector(self.input_dim[1])
        self.decoder = layers.LSTM(self.units,
                                   return_sequences=True,
                                   activation="tanh",
                                   name="decoder",
                                   input_shape=(self.input_dim[1], self.units))
        self.dense = layers.TimeDistributed(layers.Dense(self.input_dim[2]))

    @tf.function
    def call(self, x):
        # Encoder
        x0 = self.regularizer0(x)
        x1 = self.encoder1(x0)
        x11 = self.regularizer1(x1)
        x2 = self.encoder2(x11)
        x22 = self.regularizer2(x2)
        # The final encoder layer also returns the hidden and cell states,
        # which seed the decoder (see https://www.tensorflow.org/guide/keras/rnn)
        output, hs, cs = self.encoder3(x22)
        encoded_state = [hs, cs]
        repeated_vec = self.repeat(output)
        # Decoder
        decoded = self.decoder(repeated_vec, initial_state=encoded_state)
        output_decoder = self.dense(decoded)
        return output_decoder
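For context, I build and train it as a plain autoencoder. The dimensions, optimizer, and loss below are illustrative choices on my side, not fixed by the model:

model = Deep_Seq2Seq_Detector(flight_len=100, param_len=10)   # hypothetical dimensions
model.compile(optimizer="adam", loss="mse")
# X_train has shape (n_samples, flight_len, param_len); the target is the input itself
model.fit(X_train, X_train, epochs=50, batch_size=32,
          validation_data=(X_val, X_val))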
I've seen GitHub threads on this, but no straight answer: https://github.com/keras-team/keras/issues/4875
Has anyone found a solution? Do I have to use the Functional or Sequential API instead?
It seems the problem was coming from the subclassing API. I reconstructed the exact same model using the Functional API, and now model.save() / tf.keras.models.load_model() yields similar results before and after reloading.
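For anyone hitting the same wall, here is a sketch of that Functional reconstruction. build_detector is just my name for it; the layers and wiring mirror the subclassed version above:

def build_detector(flight_len, param_len, hidden_state=16):
    inputs = tf.keras.Input(shape=(flight_len, param_len))
    # Encoder
    x = layers.BatchNormalization()(inputs)
    x = layers.LSTM(hidden_state, return_sequences=True, name="encoder1")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("tanh")(x)
    x = layers.LSTM(hidden_state, return_sequences=True, name="encoder2")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("tanh")(x)
    output, hs, cs = layers.LSTM(hidden_state, return_state=True,
                                 activation="tanh", name="encoder3")(x)
    # Decoder, seeded with the encoder's final hidden and cell states
    x = layers.RepeatVector(flight_len)(output)
    x = layers.LSTM(hidden_state, return_sequences=True, activation="tanh",
                    name="decoder")(x, initial_state=[hs, cs])
    outputs = layers.TimeDistributed(layers.Dense(param_len))(x)
    return tf.keras.Model(inputs, outputs, name="LSTM")

model = build_detector(flight_len=100, param_len=10)   # hypothetical dimensions
model.save("saved_detector")
restored = tf.keras.models.load_model("saved_detector")

Because a Functional model carries its full graph in its config, load_model can rebuild the exact architecture and load the weights back into it, rather than reviving a traced subclassed object, which is presumably why the restored model now scores the same.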