python machine-learning keras nlp language-translation

Imposing grammar rules manually on a Sequence2Sequence Keras model


I have a fairly standard sequence-to-sequence translator in Keras, which looks like this:

# create model
import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense
from tensorflow.keras.models import Model

# Encoder: embed the source tokens and keep only the final LSTM states.
encoder_inputs = Input(shape=(None,))
en_x = Embedding(num_encoder_tokens, EMBEDDING_SIZE)(encoder_inputs)
encoder = LSTM(50, return_state=True)
encoder_outputs, state_h, state_c = encoder(en_x)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None,))
dex = Embedding(num_decoder_tokens, EMBEDDING_SIZE)
final_dex = dex(decoder_inputs)

decoder_lstm = LSTM(50, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(final_dex, initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.05)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['acc'])

I know it isn't a great idea, but the data I am trying to translate is not spoken language, and I want to impose further rules on the decoded sequence. One of these rules is that any word should occur only once in the decoded sequence. The rule does not apply to the sequence being encoded.

The data I am using to train the model already adheres to this rule, but the current output of the model does not. (I know this rule doesn't really make sense language-wise.)

Is there a way to do this, and if so how?


Solution

  • Why not check for duplicate words in the decoder and stop the decoding (or reject the word) when one occurs? Add the rule in the part of the decoder where char = target_index_word[word_index] and decoded_sentence += ' ' + char are executed; a sketch of such a rule follows after the loop below.

    import numpy as np

    def get_predicted_sentence(input_seq):
        # Encode the input as state vectors.
        enc_output, enc_h, enc_c = encoder_model.predict(input_seq)
        # Generate empty target sequence of length 1.
        target_seq = np.zeros((1,1))
        
        # Populate the first character of target sequence with the start character.
        target_seq[0, 0] = target_word_index['sos']
        
        # Sampling loop for a batch of sequences
        # (to simplify, here we assume a batch of size 1).
        stop_condition = False
        decoded_sentence = ""
        
        count = 0
        while not stop_condition:
            count += 1
            if count > 1000:
                # Safety cap so the loop cannot run forever.
                print('count exceeded')
                stop_condition = True
            output_words, dec_h, dec_c = decoder_model.predict([target_seq] + [enc_output, enc_h, enc_c ])
            #print(output_tokens)
            # Greedily pick the most probable next word; a uniqueness rule can be
            # enforced here (see the sketch after this function).
            word_index = np.argmax(output_words[0, -1, :])
            char = ""
            if word_index in target_index_word:
                char = target_index_word[word_index]
                decoded_sentence += ' ' + char
                print(decoded_sentence)
            else:
                stop_condition=True
            if char == 'eos' or len(decoded_sentence) >= max_input_len:
                stop_condition = True
            
            # Update the target sequence (of length 1).
            target_seq = np.zeros((1,1))
            target_seq[0, 0] = word_index
            print(target_seq[0,0])
            # Update states
            enc_h, enc_c = dec_h, dec_c
        
        return decoded_sentence
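
For the uniqueness rule itself, one option is to mask out the probabilities of words that have already been emitted before taking the argmax, so the decoder can never pick a duplicate (rather than stopping the whole decode once one appears). The following is only a minimal sketch of that idea, reusing the encoder_model, decoder_model, target_word_index, target_index_word and max_input_len from above; the set used_word_indices is new and only tracks what has been generated so far.

    import numpy as np

    def get_predicted_sentence_unique(input_seq):
        # Encode the input as state vectors, exactly as in the loop above.
        enc_output, enc_h, enc_c = encoder_model.predict(input_seq)

        # Start decoding from the 'sos' token.
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = target_word_index['sos']

        decoded_sentence = ""
        used_word_indices = set()   # indices of words already emitted

        for _ in range(1000):       # same safety cap as above
            output_words, dec_h, dec_c = decoder_model.predict(
                [target_seq] + [enc_output, enc_h, enc_c])

            # Zero out every word that was already used, then take the argmax
            # of what remains, so a duplicate can never be selected.
            probs = output_words[0, -1, :].copy()
            for idx in used_word_indices:
                probs[idx] = 0.0
            word_index = np.argmax(probs)

            if word_index not in target_index_word:
                break
            char = target_index_word[word_index]
            if char == 'eos' or len(decoded_sentence) >= max_input_len:
                break

            decoded_sentence += ' ' + char
            used_word_indices.add(word_index)

            # Feed the chosen word back in and carry the states forward.
            target_seq = np.zeros((1, 1))
            target_seq[0, 0] = word_index
            enc_h, enc_c = dec_h, dec_c

        return decoded_sentence

Stopping the decode on the first duplicate (as suggested above) is simpler, but masking lets the decoder continue with the next-best word, so the output is not cut short.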