Tags: python, keras, deep-learning, lstm, crf

adding CRF layer to LSTM flattens out learning curve


I have implemented a bi-LSTM named entity tagger in Keras with the TensorFlow backend (tf version 1.13.1). The task of the network, given a sequence of word tokens, is to tag every element of the sequence with an entity type label. I subsequently added a Conditional Random Field (CRF) layer on top of the network; in state-of-the-art named entity taggers, a CRF is typically used to improve the result of the bi-LSTM by imposing adjacency constraints on neighboring elements of the sequence.

This is the network code. The loop performs custom mini-batch training due to the nature of the data.

from keras.models import Sequential
from keras.layers import LSTM, Dense, TimeDistributed, Bidirectional, Dropout
from keras import optimizers
from keras_contrib.layers import CRF

model = Sequential()
model.add(Bidirectional(LSTM(params['hidden_size'], return_sequences=True), input_shape=(params['max_seq_len'], params['emb_size'])))
model.add(Dropout(params['dropout']))
model.add(TimeDistributed(Dense(params['n_classes'], activation=params['activation'])))
crf = CRF(params['n_classes'], sparse_target=True)
model.add(crf)

model.compile(loss=crf.loss_function, optimizer='sgd', metrics=['mse'])

for epoch in range(params['n_epochs']):

    for X,y in data:

        X = X.reshape(1, params['max_seq_len'], params['emb_size'])
        y = y.reshape(1, params['max_seq_len'], 1)

        history = model.fit(X, y, epochs=1, batch_size=params['batch_size'], verbose=1)

I expected the addition of the CRF layer to improve the network's performance; however, not only did it fail to improve the results, it also completely flattened out the learning curve, as measured by the mean squared error (bottom panel of the plot below, averaged over each epoch). This is puzzling, especially because the loss keeps decreasing (top panel of the plot below, likewise averaged over each epoch). In addition, the model's performance on the test set dropped to zero.

[Plot: training loss (top) and mean squared error (bottom), averaged per epoch]

Is there an error in my implementation of the CRF? Or could it be because the CRF loss is not measured via mean squared error?

A simulated data set that can be used to reproduce the problem is provided in this SO question (same implementation as here, but without the CRF layer).


Solution

  • Your network looks fine.
    A few changes and iterations I would try are the following (see the sketch after this list):

    1. Instead of using mse, use crf_marginal_accuracy as the metric and initialize the CRF with learn_mode='marginal'.
       This is because you are using the model for NER and your targets are not one-hot encoded (sparse_target=True).
    2. Use the adam optimizer instead of sgd: sgd does not take the whole dataset into consideration when computing updates, and for a TimeDistributed model with more than 3 dimensions it does not optimize well.
    3. You might also want to use crf_loss as the loss function.
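
Putting these suggestions together, a minimal sketch of the adjusted model might look like the following (assuming the same params dict as in your question and the crf_loss / crf_marginal_accuracy helpers that ship with keras_contrib; your custom mini-batch training loop stays unchanged):

from keras.models import Sequential
from keras.layers import LSTM, Dense, TimeDistributed, Bidirectional, Dropout
from keras_contrib.layers import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_marginal_accuracy

model = Sequential()
model.add(Bidirectional(LSTM(params['hidden_size'], return_sequences=True),
                        input_shape=(params['max_seq_len'], params['emb_size'])))
model.add(Dropout(params['dropout']))
model.add(TimeDistributed(Dense(params['n_classes'], activation=params['activation'])))

# learn_mode='marginal' makes the layer output marginal probabilities,
# which is what crf_marginal_accuracy expects
model.add(CRF(params['n_classes'], sparse_target=True, learn_mode='marginal'))

# crf_loss instead of mse as loss/metric, adam instead of sgd
model.compile(loss=crf_loss, optimizer='adam', metrics=[crf_marginal_accuracy])

With learn_mode='marginal' the CRF is trained on per-timestep marginal likelihoods rather than the joint (Viterbi) objective, so crf_marginal_accuracy is the matching metric to monitor instead of mse.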