Tags: keras, ocr, decoding, ctc

How can I add the decode_batch_predictions() method to the Keras Captcha OCR model?


The current Keras Captcha OCR model returns CTC-encoded output, which requires decoding after inference.

To decode it, a utility function has to be run as a separate post-processing step:

preds = prediction_model.predict(batch_images)
pred_texts = decode_batch_predictions(preds)

The decoding utility function uses keras.backend.ctc_decode, which in turn uses either a greedy or a beam search decoder.

# A utility function to decode the output of the network
# (num_to_char and max_length are defined earlier in the Captcha OCR example)
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_length
    ]
    # Iterate over the results and map the character indices back to text
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text
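
For reference, switching from greedy to beam search decoding is a small change to the same utility (a sketch; the beam_width below is illustrative, not a value taken from the example):

# Sketch: same utility as above, but using ctc_decode's beam search decoder.
# beam_width/top_paths are illustrative defaults, not values from the example.
def decode_batch_predictions_beam(pred, beam_width=100):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # greedy=False makes keras.backend.ctc_decode run beam search instead
    results = keras.backend.ctc_decode(
        pred, input_length=input_len, greedy=False, beam_width=beam_width, top_paths=1
    )[0][0][:, :max_length]
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text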

I would like to train a Captcha OCR model using Keras that returns the CTC-decoded text as its output, without requiring an additional decoding step after inference.

How would I achieve this?


Solution

  • The most robust way to achieve this is to make the decoding step part of the model itself, via a Lambda layer that wraps tf.keras.backend.ctc_decode:

    def CTCDecoder():
      def decoder(y_pred):
        input_shape = tf.keras.backend.shape(y_pred)
        input_length = tf.ones(shape=input_shape[0]) * tf.keras.backend.cast(
            input_shape[1], 'float32')
        # Greedy CTC decoding of the softmax output
        unpadded = tf.keras.backend.ctc_decode(y_pred, input_length)[0][0]
        unpadded_shape = tf.keras.backend.shape(unpadded)
        # Pad with -1 so every batch yields a fixed-width output tensor
        padded = tf.pad(unpadded,
                        paddings=[[0, 0], [0, input_shape[1] - unpadded_shape[1]]],
                        constant_values=-1)
        return padded

      return tf.keras.layers.Lambda(decoder, name='decode')

    Then define the prediction model as follows:

    prediction_model = keras.models.Model(inputs=inputs, outputs=CTCDecoder()(model.output))
    

    Credit goes to tulasiram58827.
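
    With the decoder baked into the graph, prediction_model.predict() returns CTC-decoded
    integer sequences padded with -1; mapping them back to strings only needs the num_to_char
    lookup from the original example. A minimal sketch, assuming the same num_to_char and
    max_length as in the question:

    # Sketch: the model output is already decoded; just strip padding and map to characters.
    decoded = prediction_model.predict(batch_images)   # (batch, time) integer indices, -1 padded
    pred_texts = []
    for seq in decoded[:, :max_length]:
        seq = seq[seq != -1]                           # drop the -1 padding
        pred_texts.append(
            tf.strings.reduce_join(num_to_char(seq)).numpy().decode("utf-8")
        )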

    This implementation supports exporting to TFLite, but only in float32. Quantized (int8) TFLite export still throws an error; that is an open ticket with the TF team.
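
    A float32 conversion along those lines might look like the sketch below (not necessarily
    the exact script used; enabling Select TF ops is an assumption here, since ctc_decode has
    no builtin TFLite kernel):

    # Sketch of a float32 TFLite export for the prediction_model defined above.
    converter = tf.lite.TFLiteConverter.from_keras_model(prediction_model)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # regular TFLite ops
        tf.lite.OpsSet.SELECT_TF_OPS,    # fall back to TF ops for the CTC decoding step
    ]
    tflite_model = converter.convert()
    with open("captcha_ocr_ctc.tflite", "wb") as f:
        f.write(tflite_model)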