I have implemented a bi-LSTM named entity tagger in Keras with a TensorFlow backend (tf version 1.13.1). Given a sequence of word tokens, the network tags every element of the sequence with an entity type label. I subsequently added a Conditional Random Field (CRF) layer on top of the network; in state-of-the-art named entity taggers, a CRF is typically used to improve the bi-LSTM's results by imposing adjacency constraints on neighboring elements of the sequence.
This is the network code. The loop performs custom mini-batch training due to the nature of the data.
from keras.models import Sequential
from keras.layers import LSTM, Dense, TimeDistributed, Bidirectional, Dropout
from keras import optimizers
from keras_contrib.layers import CRF
model = Sequential()
model.add(Bidirectional(LSTM(params['hidden_size'], return_sequences=True), input_shape=(params['max_seq_len'], params['emb_size'])))
model.add(Dropout(params['dropout']))
model.add(TimeDistributed(Dense(params['n_classes'], activation=params['activation'])))
# keep a reference to the CRF layer so its loss function can be passed to compile()
crf = CRF(params['n_classes'], sparse_target=True)
model.add(crf)
model.compile(loss=crf.loss_function, optimizer='sgd', metrics=['mse'])
for epoch in range(params['n_epochs']):
    for X, y in data:
        X = X.reshape(1, params['max_seq_len'], params['emb_size'])
        y = y.reshape(1, params['max_seq_len'], 1)
        history = model.fit(X, y, epochs=1, batch_size=params['batch_size'], verbose=1)
I expected the addition of the CRF layer to improve the network's performance; instead, it not only failed to improve the results, it completely flattened out the learning curve, as measured by mean squared error (bottom panel of the plot below, averaged over each epoch). This is puzzling, especially because the loss keeps decreasing (top panel of the plot below, likewise averaged over each epoch). In addition, the model's performance on the test set dropped to zero.
Is there an error in my implementation of the CRF? Or could it be because the CRF loss is not meaningfully measured by mean squared error?
A simulated data set that can be used to reproduce the problem is provided in this SO question (same implementation as here, but without the CRF layer).
Your network looks fine.
A few changes and iterations I would try are the following:
- Instead of mse, use crf_marginal_accuracy as the metric, and initialize the CRF with learn_mode='marginal', sparse_target=True.
- Use the adam optimizer instead of sgd: sgd does not take the whole dataset into consideration while computing updates, and for a TimeDistributed model where you have more than 3 dims, sgd does not optimize well.
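Below is a minimal sketch of those changes, reusing the params dict from your question and assuming the keras_contrib API (crf_marginal_accuracy lives in keras_contrib.metrics); treat it as a starting point rather than a drop-in fix:

from keras.models import Sequential
from keras.layers import LSTM, Dense, TimeDistributed, Bidirectional, Dropout
from keras_contrib.layers import CRF
from keras_contrib.metrics import crf_marginal_accuracy

model = Sequential()
model.add(Bidirectional(LSTM(params['hidden_size'], return_sequences=True),
                        input_shape=(params['max_seq_len'], params['emb_size'])))
model.add(Dropout(params['dropout']))
model.add(TimeDistributed(Dense(params['n_classes'], activation=params['activation'])))
# learn_mode='marginal' trains the CRF on per-timestep marginal probabilities
# rather than the joint sequence likelihood
crf = CRF(params['n_classes'], learn_mode='marginal', sparse_target=True)
model.add(crf)
# crf_marginal_accuracy replaces mse, which is not a meaningful metric for CRF output;
# adam replaces sgd as suggested above
model.compile(loss=crf.loss_function, optimizer='adam', metrics=[crf_marginal_accuracy])

Your custom training loop can stay unchanged; only the CRF construction and the compile call differ.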