I'm trying to implement a custom model in TensorFlow by subclassing `tf.keras.Model`.
I need a way to add n stacked LSTM layers to the model.
For instance, consider the following implementation:
class CustomizedLSTM(tf.keras.Model):
    """Binary text classifier: embedding -> single LSTM -> dense -> sigmoid.

    Args:
        num_hidden_layers: accepted but not used by this version; the model
            always contains exactly one LSTM layer.
        vocab_size: vocabulary size passed to the embedding layer.
    """

    def __init__(self, num_hidden_layers, vocab_size):
        super().__init__()
        # Token ids -> 300-dimensional embedding vectors.
        self.embedding = tf.keras.layers.Embedding(vocab_size, 300)
        # One LSTM with 256 units; return_sequences is left at its default.
        self.first_lstm = tf.keras.layers.LSTM(256, activation="relu")
        self.first_dense = tf.keras.layers.Dense(64, activation="relu")
        # A single sigmoid unit produces the classification score.
        self.classification_layer = tf.keras.layers.Dense(1, activation="sigmoid")

    def call(self, inputs):
        # Apply each stage in order, feeding one stage's output into the next.
        pipeline = (
            self.embedding,
            self.first_lstm,
            self.first_dense,
            self.classification_layer,
        )
        out = inputs
        for stage in pipeline:
            out = stage(out)
        return out
I would like to make the number of hidden LSTM layers configurable. In other words, I would like to create a model with `num_hidden_layers` stacked LSTMs.
Is this possible? Can you please help me?
class CustomizedLSTM(tf.keras.Model):
    """Binary text classifier with a configurable stack of LSTM layers.

    Architecture: embedding -> ``num_hidden_layers`` stacked LSTMs
    -> dense(64) -> dense(1, sigmoid).

    Args:
        num_hidden_layers: number of stacked LSTM layers; must be >= 1.
        dim_per_hidden: sequence of LSTM unit counts, one per layer; it must
            contain at least ``num_hidden_layers`` entries.
        vocab_size: vocabulary size passed to the embedding layer.

    Raises:
        ValueError: if ``num_hidden_layers`` is less than 1.
    """

    def __init__(self, num_hidden_layers, dim_per_hidden, vocab_size):
        # Keras requires the base-class constructor to run before any layer
        # (or list of layers) is assigned as an attribute; assigning first
        # raises a RuntimeError and breaks layer tracking.
        super().__init__()
        if num_hidden_layers < 1:
            raise ValueError("num_hidden_layers must be at least 1")

        self.embedding = tf.keras.layers.Embedding(vocab_size, 300)

        # To stack multiple LSTMs, every LSTM except the last must use
        # return_sequences=True, because its full output sequence is the
        # input to the next LSTM. The last LSTM uses return_sequences=False
        # so that only its final output reaches the dense head; change this
        # according to your needs.
        last = num_hidden_layers - 1
        self.lstms = [
            tf.keras.layers.LSTM(
                dim_per_hidden[i],
                activation="relu",
                return_sequences=(i < last),
            )
            for i in range(num_hidden_layers)
        ]

        self.first_dense = tf.keras.layers.Dense(64, activation="relu")
        self.classification_layer = tf.keras.layers.Dense(1, activation="sigmoid")

    def call(self, inputs):
        x = self.embedding(inputs)
        # Feed the output of each LSTM into the next one.
        for layer in self.lstms:
            x = layer(x)
        x = self.first_dense(x)
        return self.classification_layer(x)
I added another parameter, `dim_per_hidden`: a list of integers giving the number of units in each LSTM layer (it must contain at least `num_hidden_layers` entries).