I have the following model written with the Keras Sequential API:
# Hyperparameters consumed by get_model below.
config = dict(
    learning_rate=0.001,
    lstm_neurons=32,            # units in the LSTM layer
    lstm_activation='tanh',     # activation passed to the LSTM layer
    dropout_rate=0.08,          # rate of the Dropout added after the LSTM block
    batch_size=128,
    # One Dense + BatchNormalization pair is added per entry.
    dense_layers=[
        dict(neurons=32, activation='relu'),
        dict(neurons=32, activation='relu'),
    ],
)
def get_model(num_features, output_size):
    """Build and compile the Sequential LSTM model described by ``config``.

    Layer stack: ragged Input -> LSTM -> BatchNormalization -> Dropout (only if
    ``config`` has ``'dropout_rate'``) -> for each entry of
    ``config['dense_layers']``: Dense -> BatchNormalization -> Dropout (only if
    that entry has its own ``'dropout_rate'``) -> Dense(output_size, sigmoid).

    Args:
        num_features: Size of the last input axis. The time axis is None, so
            sequences of different lengths are accepted.
        output_size: Number of units in the sigmoid output layer.

    Returns:
        The compiled model (MSE loss and metric, Adam optimizer).
    """
    # Read the learning rate from config. The previous hard-coded 0.001
    # silently ignored config['learning_rate'].
    opt = Adam(learning_rate=config.get('learning_rate', 0.001))
    model = Sequential()
    model.add(Input(shape=[None, num_features], dtype=tf.float32, ragged=True))
    model.add(LSTM(config['lstm_neurons'], activation=config['lstm_activation']))
    model.add(BatchNormalization())
    if 'dropout_rate' in config:
        model.add(Dropout(config['dropout_rate']))
    for layer in config['dense_layers']:
        model.add(Dense(layer['neurons'], activation=layer['activation']))
        model.add(BatchNormalization())
        # Only dense entries that carry their own 'dropout_rate' get a Dropout.
        if 'dropout_rate' in layer:
            model.add(Dropout(layer['dropout_rate']))
    model.add(Dense(output_size, activation='sigmoid'))
    model.compile(loss='mse', optimizer=opt, metrics=['mse'])
    # summary() prints the table itself and returns None; wrapping it in
    # print() only emitted an extra "None" line.
    model.summary()
    return model
Because I'm using a distributed training framework, I need to rewrite this model with the Keras subclassing API instead. I've read the docs but couldn't work out how to do it.
Here is one equivalent subclassed implementation, though I haven't tested it.
import tensorflow as tf
# Hyperparameters from the question, reused by the subclassed model below.
config = dict(
    learning_rate=0.001,
    lstm_neurons=32,
    lstm_activation='tanh',
    dropout_rate=0.08,
    batch_size=128,
    # Each entry describes one hidden Dense layer.
    dense_layers=[
        dict(neurons=32, activation='relu'),
        dict(neurons=32, activation='relu'),
    ],
)
# Subclassed API Model
class MySubClassed(tf.keras.Model):
    """Subclassed-API version of the Sequential model built by ``get_model``.

    Layer stack, driven by the hyperparameter dict:
      LSTM -> BatchNormalization -> Dropout (only if the dict has
      ``'dropout_rate'``) -> for every entry of ``dense_layers``:
      Dense -> BatchNormalization -> Dropout (only if that entry has its own
      ``'dropout_rate'``, as in ``get_model``) -> Dense(output_size, sigmoid).

    Args:
        output_size: Number of units in the sigmoid output layer.
        model_config: Optional hyperparameter dict with the same keys as the
            module-level ``config``. Defaults to ``config``.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, output_size, model_config=None, **kwargs):
        super().__init__(**kwargs)
        cfg = config if model_config is None else model_config

        self.lstm = tf.keras.layers.LSTM(cfg['lstm_neurons'],
                                         activation=cfg['lstm_activation'])
        self.bn = tf.keras.layers.BatchNormalization()
        # Dropout after the LSTM block only when a global rate is configured.
        self.dp1 = (tf.keras.layers.Dropout(cfg['dropout_rate'])
                    if 'dropout_rate' in cfg else None)

        # One [Dense, BatchNorm, (Dropout)] block per configured dense layer.
        # The previous version reassigned dense1/dense2 on every loop
        # iteration, which built unused layers and hard-coded exactly two
        # dense layers. Keras tracks layers stored in (nested) list attributes.
        self.hidden_blocks = []
        for spec in cfg['dense_layers']:
            block = [
                tf.keras.layers.Dense(spec['neurons'],
                                      activation=spec['activation']),
                tf.keras.layers.BatchNormalization(),
            ]
            if 'dropout_rate' in spec:
                block.append(tf.keras.layers.Dropout(spec['dropout_rate']))
            self.hidden_blocks.append(block)

        self.out = tf.keras.layers.Dense(output_size,
                                         activation='sigmoid')

    def call(self, inputs, training=None, **kwargs):
        """Run the forward pass.

        ``training`` defaults to None rather than True, so Keras chooses the
        mode (fit() runs in training mode, predict()/evaluate() in inference
        mode). With a True default, plain ``model(x)`` calls kept Dropout
        active and BatchNormalization using batch statistics.
        """
        x = self.lstm(inputs)
        x = self.bn(x, training=training)
        if self.dp1 is not None:
            x = self.dp1(x, training=training)
        for dense, bn, *dropouts in self.hidden_blocks:
            x = dense(x)
            x = bn(x, training=training)
            for dropout in dropouts:
                x = dropout(x, training=training)
        return self.out(x)

    # A convenient way to get model summary
    # and plot in subclassed api
    def build_graph(self, raw_shape):
        """Return a functional ``tf.keras.Model`` wrapping this network.

        Traces ``call`` on a symbolic ragged Input so that ``summary()`` and
        ``plot_model`` can show each layer with its output shape.

        Args:
            raw_shape: Size of the last input axis (features per timestep).
                The time axis is left as None.
        """
        x = tf.keras.layers.Input(shape=(None, raw_shape),
                                  ragged=True)
        return tf.keras.Model(inputs=[x],
                              outputs=self.call(x))
Build and compile the model.
# Instantiate the subclassed model with a single sigmoid output and compile
# it. The learning rate comes from config rather than a duplicated literal.
s = MySubClassed(output_size=1)
s.compile(
    loss='mse',
    metrics=['mse'],
    optimizer=tf.keras.optimizers.Adam(learning_rate=config['learning_rate']),
)
Pass a tensor through the model once so it creates its weights, then check how many there are.
# One forward pass on an all-ones batch makes the subclassed model create
# its weights; then report how many exist.
raw_input = (16, 16, 16)  # presumably (batch, timesteps, features) — the last axis matches build_graph(16)
y = s(tf.ones(raw_input))
n_weights, n_trainable = len(s.weights), len(s.trainable_weights)
print("weights:", n_weights)
print("trainable weights:", n_trainable)
weights: 21
trainable weights: 15
Summarize and visualize the model graph.
# Wrap the subclassed model in a functional graph so summary() lists every layer.
functional_view = s.build_graph(16)
functional_view.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, None, 16)] 0
_________________________________________________________________
lstm (LSTM) (None, 32) 6272
_________________________________________________________________
batch_normalization (BatchNo (None, 32) 128
_________________________________________________________________
dropout (Dropout) (None, 32) 0
_________________________________________________________________
dense_2 (Dense) (None, 32) 1056
_________________________________________________________________
batch_normalization_3 (Batch (None, 32) 128
_________________________________________________________________
dropout_1 (Dropout) (None, 32) 0
_________________________________________________________________
dense_3 (Dense) (None, 32) 1056
_________________________________________________________________
batch_normalization_4 (Batch (None, 32) 128
_________________________________________________________________
dropout_2 (Dropout) (None, 32) 0
_________________________________________________________________
dense_4 (Dense) (None, 1) 33
=================================================================
Total params: 8,801
Trainable params: 8,609
Non-trainable params: 192
# Draw the functional view of the model top-to-bottom, including shapes,
# dtypes and layer names.
# NOTE(review): plot_model needs pydot and graphviz installed — confirm in the target env.
plot_options = dict(
    show_shapes=True,
    show_dtype=True,
    show_layer_names=True,
    rankdir="TB",
)
tf.keras.utils.plot_model(s.build_graph(16), **plot_options)