Tags: python, tensorflow, keras, tensorflow2.0, ctc

TensorFlow Callback as Custom Metric for CTC


To report more metrics during the training of my model (written in TensorFlow 2.1.0), such as the Character Error Rate (CER) and Word Error Rate (WER), I created a callback that I pass to my model's fit function. It computes the CER and WER at the end of each epoch.

The callback is my second choice: I originally wanted to implement these as custom metrics, but custom metrics are restricted to Keras backend functionality. Does anyone have advice on how to convert the callback below into a custom metric (which could then be computed during training on the validation and/or training data)?
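For concreteness, this is the per-sample computation the callback performs, shown as a toy example with made-up values (using the editdistance package):

import editdistance

pred = [7, 3, 5]      # decoded label indices for one sample (made up)
truth = [7, 4, 5, 9]  # ground-truth label indices (made up)

# CER: edit distance normalized by the longer sequence -> 2 / 4 = 0.5
cer = editdistance.eval(pred, truth) / max(len(pred), len(truth))

# WER as used here is sequence-level: 1.0 whenever the sequences differ
wer = float(pred != truth)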

I hit several roadblocks attempting the conversion. For reference, here is the callback:

import editdistance
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

class Metrics(tf.keras.callbacks.Callback):
    def __init__(self, valid_data, steps):
        """
        valid_data is a TFRecordDataset with batches of 100 elements per batch,
        shuffled and repeated infinitely.
        steps defines the number of batches per epoch.
        """
        super(Metrics, self).__init__()
        self.valid_data = valid_data
        self.steps = steps

    def on_train_begin(self, logs=None):
        self.cer = []
        self.wer = []

    def on_epoch_end(self, epoch, logs=None):
        # Pull `steps` batches from the (infinitely repeating) validation set.
        imgs = []
        labels = []
        for idx, (img, label) in enumerate(self.valid_data.as_numpy_iterator()):
            if idx >= self.steps:
                break
            imgs.append(img)
            labels.extend(label)

        imgs = np.array(imgs)
        labels = np.array(labels)

        # Predict batch by batch; predict concatenates the per-batch outputs.
        out = self.model.predict((batch for batch in imgs))

        # All predictions share the same number of time steps.
        input_length = len(max(out, key=len))
        out_len = np.asarray([input_length for _ in range(len(out))])

        # Greedy CTC decoding; decoded sequences come back padded with -1.
        decode, log = K.ctc_decode(out,
                                   out_len,
                                   greedy=True)

        # Strip the -1 padding from the decoded label sequences.
        decode = [[int(p) for p in x if p != -1] for x in decode[0]]

        for pred, lab in zip(decode, labels):
            dist = editdistance.eval(pred, lab)
            # CER: edit distance normalized by the longer of the two sequences.
            self.cer.append(dist / max(len(pred), len(lab)))
            # "WER" here is sequence-level: 1 if the prediction differs at all.
            self.wer.append(not np.array_equal(pred, lab))

        # self.cer / self.wer are only reset in on_train_begin, so these are
        # running means over all epochs so far.
        print("Mean CER: {}".format(np.mean(self.cer)))
        print("Mean WER: {}".format(np.mean(self.wer)))

Solution

  • Solved in TF 2.3.1, but this should apply to earlier 2.x releases as well.

    A few remarks are inlined as comments in the code below:

    import tensorflow as tf
    from tensorflow.keras import backend as K

    class CERMetric(tf.keras.metrics.Metric):
        """
        A custom Keras metric to compute the Character Error Rate.
        """
        def __init__(self, name='CER_metric', **kwargs):
            super(CERMetric, self).__init__(name=name, **kwargs)
            self.cer_accumulator = self.add_weight(name="total_cer", initializer="zeros")
            self.counter = self.add_weight(name="cer_count", initializer="zeros")

        def update_state(self, y_true, y_pred, sample_weight=None):
            input_shape = K.shape(y_pred)
            # One sequence length per sample, all equal to the time dimension.
            input_length = tf.ones(shape=input_shape[0]) * K.cast(input_shape[1], 'float32')

            # Greedy CTC decoding; decoded sequences come back padded with -1.
            decode, log = K.ctc_decode(y_pred,
                                       input_length,
                                       greedy=True)

            decode = K.ctc_label_dense_to_sparse(decode[0], K.cast(input_length, 'int32'))
            # Use the width of y_true as the label lengths, so label tensors
            # narrower than the prediction's time dimension are handled too.
            label_length = tf.ones(shape=input_shape[0], dtype='int32') * K.cast(K.shape(y_true)[1], 'int32')
            y_true_sparse = K.ctc_label_dense_to_sparse(y_true, label_length)

            # Drop the -1 padding before computing the edit distance.
            decode = tf.sparse.retain(decode, tf.not_equal(decode.values, -1))
            # normalize=True divides each distance by the ground-truth length.
            distance = tf.edit_distance(decode, y_true_sparse, normalize=True)

            self.cer_accumulator.assign_add(tf.reduce_sum(distance))
            # len(y_true) fails in graph mode when the batch dimension is
            # dynamic, so read the batch size from the shape tensor instead.
            self.counter.assign_add(tf.cast(input_shape[0], 'float32'))

        def result(self):
            return tf.math.divide_no_nan(self.cer_accumulator, self.counter)

        def reset_states(self):
            self.cer_accumulator.assign(0.0)
            self.counter.assign(0.0)

    class WERMetric(tf.keras.metrics.Metric):
        """
        A custom Keras metric to compute the Word Error Rate (here: the
        fraction of sequences that are not predicted perfectly).
        """
        def __init__(self, name='WER_metric', **kwargs):
            super(WERMetric, self).__init__(name=name, **kwargs)
            self.wer_accumulator = self.add_weight(name="total_wer", initializer="zeros")
            self.counter = self.add_weight(name="wer_count", initializer="zeros")

        def update_state(self, y_true, y_pred, sample_weight=None):
            input_shape = K.shape(y_pred)
            input_length = tf.ones(shape=input_shape[0]) * K.cast(input_shape[1], 'float32')

            decode, log = K.ctc_decode(y_pred,
                                       input_length,
                                       greedy=True)

            decode = K.ctc_label_dense_to_sparse(decode[0], K.cast(input_length, 'int32'))
            label_length = tf.ones(shape=input_shape[0], dtype='int32') * K.cast(K.shape(y_true)[1], 'int32')
            y_true_sparse = K.ctc_label_dense_to_sparse(y_true, label_length)

            decode = tf.sparse.retain(decode, tf.not_equal(decode.values, -1))
            distance = tf.edit_distance(decode, y_true_sparse, normalize=True)

            # Count the sequences with a nonzero edit distance, i.e. the misses.
            incorrect_sequences = tf.reduce_sum(tf.cast(tf.not_equal(distance, 0), tf.float32))

            self.wer_accumulator.assign_add(incorrect_sequences)
            self.counter.assign_add(tf.cast(input_shape[0], 'float32'))

        def result(self):
            return tf.math.divide_no_nan(self.wer_accumulator, self.counter)

        def reset_states(self):
            self.wer_accumulator.assign(0.0)
            self.counter.assign(0.0)
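
    With these classes, the metrics attach directly at compile time and are then reported on both training and validation data every epoch. A minimal sketch of the wiring; my_ctc_loss, train_dataset and valid_dataset are placeholders for whatever CTC loss and datasets the model already uses:

    model.compile(optimizer='adam',
                  loss=my_ctc_loss,  # placeholder for your existing CTC loss
                  metrics=[CERMetric(), WERMetric()])

    model.fit(train_dataset, validation_data=valid_dataset, epochs=20)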