python-3.x, tensorflow2.0, tf.keras, crf

How to add a CRF layer to a TensorFlow Sequential model?


I am trying to implement a CRF layer in a TensorFlow Sequential model for a NER problem, and I am not sure how to do it. Previously, when I implemented a CRF, I used the CRF layer from Keras with TensorFlow as the backend, i.e. I created the entire model in Keras instead of TensorFlow and then passed the entire model through the CRF. It worked.

But now I want to develop the model in TensorFlow, since TensorFlow 2.0.0 beta already has Keras built in. I am trying to build a Sequential model and add a CRF layer after a bidirectional LSTM layer, but I am not sure how to do that. I have gone through the CRF documentation in tensorflow-addons, and it contains several functions, such as a forward CRF function, but I am not sure how to use them as a layer. Is it possible at all to implement a CRF layer inside a Sequential TensorFlow model, or do I need to build the model graph from scratch and then use the CRF functions? Can anyone please help me with it? Thanks in advance.


Solution

  • [Diagram: BiLSTM-CRF, where the BiLSTM states P1, P2, P3, P4 for each character feed into the CRF layer]

    In the training process:

    You can refer to this API:

    tfa.text.crf_log_likelihood(
        inputs,
        tag_indices,
        sequence_lengths,
        transition_params=None
    )
    

    The inputs are the unary potentials (analogous to those in logistic regression; you can refer to this answer). In your case, they are the logits (usually not the distributions after the softmax activation function), i.e. the states of the BiLSTM for each character in the encoder (P1, P2, P3, P4 in the diagram above).

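    The inputs argument is a [batch_size, max_seq_len, num_tags] tensor, so the BiLSTM states must be projected to one score per tag (for example with a Dense layer of size num_tags). The tag_indices are the target tag indices, a [batch_size, max_seq_len] integer matrix, and sequence_lengths is a [batch_size] vector holding the true (unpadded) length of each sequence in the batch.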
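    As a minimal sketch of the argument shapes, assuming a padded batch and a hypothetical boolean `mask` that marks real tokens (NumPy arrays stand in for the tensors, and all values are made up): sequence_lengths is simply the number of real tokens per sentence.

```python
import numpy as np

batch_size, max_seq_len, num_tags = 2, 4, 3

# Unary potentials: one score per (sentence, position, tag), e.g. the
# BiLSTM + Dense outputs (raw logits, not softmax probabilities).
rng = np.random.default_rng(0)
inputs = rng.normal(size=(batch_size, max_seq_len, num_tags)).astype(np.float32)

# Gold tag ids, padded with 0 after the end of each sentence.
tag_indices = np.array([[1, 2, 0, 0],
                        [2, 1, 1, 0]], dtype=np.int32)

# Hypothetical padding mask: True for real tokens, False for padding.
mask = np.array([[True, True, False, False],
                 [True, True, True, False]])

# sequence_lengths counts the real tokens in each sentence.
sequence_lengths = mask.sum(axis=1).astype(np.int32)
print(inputs.shape, tag_indices.shape, sequence_lengths)
```

    Positions beyond each sequence length are ignored by the CRF, so the padding value in tag_indices does not matter.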

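    The transition_params are the binary potentials (the score for moving from one tag at one time step to another tag at the next), a [num_tags, num_tags] matrix. You can create this matrix yourself or pass None and let the API create it for you; either way, crf_log_likelihood returns it together with the per-example log-likelihoods. In a training loop you would typically create it once as a trainable variable, so the same matrix is updated and reused across steps.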
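    To make the unary/binary split concrete, here is a minimal NumPy sketch of what the log-likelihood amounts to for one unpadded sequence: the path score (the unary scores of the gold tags plus transitions[prev, next] for every step) minus the log-partition function, computed with the forward algorithm. The helper names are hypothetical and not part of tensorflow-addons; the library function does the equivalent on padded batches using sequence_lengths.

```python
import itertools
import numpy as np


def logsumexp(x, axis=None):
    # Numerically stable log(sum(exp(x))) along the given axis.
    m = np.max(x, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def sequence_score(unary, tags, transitions):
    tags = np.asarray(tags)
    # Unary part: emission score of each chosen tag at each position.
    score = unary[np.arange(len(tags)), tags].sum()
    # Binary part: transitions[prev_tag, next_tag] summed along the path.
    score += transitions[tags[:-1], tags[1:]].sum()
    return score


def log_partition(unary, transitions):
    # Forward algorithm: alpha[j] = log-sum of the scores of all prefixes ending in tag j.
    alpha = unary[0]
    for t in range(1, len(unary)):
        # alpha[i] + transitions[i, j], reduced over the previous tag i.
        alpha = logsumexp(alpha[:, None] + transitions, axis=0) + unary[t]
    return logsumexp(alpha)


def crf_log_likelihood_single(unary, tags, transitions):
    # log p(tags | unary) = score(tags) - log Z
    return sequence_score(unary, tags, transitions) - log_partition(unary, transitions)


rng = np.random.default_rng(0)
seq_len, num_tags = 4, 3
unary = rng.normal(size=(seq_len, num_tags))
transitions = rng.normal(size=(num_tags, num_tags))
gold = np.array([1, 2, 2, 0])

ll = crf_log_likelihood_single(unary, gold, transitions)

# Brute-force check: log Z is the log-sum over all num_tags ** seq_len paths.
all_scores = [sequence_score(unary, p, transitions)
              for p in itertools.product(range(num_tags), repeat=seq_len)]
assert np.isclose(log_partition(unary, transitions), logsumexp(np.array(all_scores)))
print(ll)  # the log of a probability, so it is negative
```

    The training loss is then the negative mean of the per-sentence log-likelihoods, and both the BiLSTM weights and the transition matrix receive gradients from it.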

    In the inference process:
    You just use this API:

    tfa.text.viterbi_decode(
        score,
        transition_params
    ) 
    

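    Here score is the same kind of input as in training (the P1, P2, P3, P4 states), but for a single sequence: a [seq_len, num_tags] matrix without padding. The transition_params is the matrix learned during training. The function returns the highest-scoring tag sequence as a list of tag indices, together with its score.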
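    For intuition, here is a minimal NumPy sketch of the Viterbi recursion behind this kind of decoding: keep the best score of any path ending in each tag, remember which previous tag produced it, then follow the back-pointers from the best final tag. The function name viterbi is hypothetical; its return value mirrors the documented (tag list, score) pair, and the result is checked against brute force.

```python
import itertools
import numpy as np


def viterbi(score, transitions):
    # score: [seq_len, num_tags] unary potentials for ONE unpadded sequence.
    # transitions: [num_tags, num_tags], transitions[i, j] = score of tag i -> tag j.
    trellis = np.zeros_like(score)
    backpointers = np.zeros(score.shape, dtype=np.int64)
    trellis[0] = score[0]
    for t in range(1, score.shape[0]):
        # candidates[i, j]: best path ending in tag i at t-1, then moving to tag j.
        candidates = trellis[t - 1][:, None] + transitions
        trellis[t] = score[t] + candidates.max(axis=0)
        backpointers[t] = candidates.argmax(axis=0)
    # Walk the back-pointers from the best final tag to the first position.
    best = [int(trellis[-1].argmax())]
    for bp in backpointers[:0:-1]:
        best.append(int(bp[best[-1]]))
    best.reverse()
    return best, float(trellis[-1].max())


rng = np.random.default_rng(1)
score = rng.normal(size=(5, 3))
transitions = rng.normal(size=(3, 3))
path, best_score = viterbi(score, transitions)


def path_score(tags):
    tags = np.asarray(tags)
    return score[np.arange(len(tags)), tags].sum() + transitions[tags[:-1], tags[1:]].sum()


brute = max(itertools.product(range(3), repeat=5), key=path_score)
assert path == list(brute)
assert np.isclose(best_score, path_score(path))
```

    For batched tensors inside a TensorFlow graph, tensorflow-addons also provides tfa.text.crf_decode(potentials, transition_params, sequence_length), which takes padded inputs plus the sequence lengths instead of one sequence at a time.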