python, tensorflow, keras, ocr, ctc

Custom metrics work with eager execution only


Based on the answer to a question I asked earlier, I'm trying to make the custom metrics word_accuracy and char_accuracy work with a CRNN-CTC model implementation in TensorFlow. Everything works perfectly fine in the linked notebook after running the following lines:

import tensorflow as tf

tf.config.run_functions_eagerly(True)

Here's the custom CTC layer as well as the accuracy calculation function:

from tensorflow.keras.layers import Layer  # used by CTCLayer below


def calculate_accuracy(y_true, y_pred, metric, unknown_placeholder):
    y_pred = tf.stack(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    unknown_indices = tf.where(y_pred == -1)
    y_pred = tf.tensor_scatter_nd_update(
        y_pred,
        unknown_indices,
        tf.cast(tf.ones(unknown_indices.shape[0]) * unknown_placeholder, tf.int64),
    )
    # Note: the static .shape attributes below are only populated under eager execution.
    if metric == 'word':
        return tf.where(tf.reduce_all(y_true == y_pred, 1)).shape[0] / y_true.shape[0]
    if metric == 'char':
        return tf.where(y_true == y_pred).shape[0] / tf.reduce_prod(y_true.shape)
    return 0


class CTCLayer(Layer):
    def __init__(self, max_label_length, unknown_placeholder, **kwargs):
        super().__init__(**kwargs)
        self.max_label_length = max_label_length
        self.unknown_placeholder = unknown_placeholder

    def call(self, *args):
        y_true, y_pred = args
        batch_length = tf.cast(tf.shape(y_true)[0], dtype='int64')
        input_length = tf.cast(tf.shape(y_pred)[1], dtype='int64')
        label_length = tf.cast(tf.shape(y_true)[1], dtype='int64')
        input_length = input_length * tf.ones(shape=(batch_length, 1), dtype='int64')
        label_length = label_length * tf.ones(shape=(batch_length, 1), dtype='int64')
        loss = tf.keras.backend.ctc_batch_cost(
            y_true, y_pred, input_length, label_length
        )
        if y_true.shape[1] is not None:  # this is to prevent an error at model creation
            # decode_batch_predictions is defined in the linked notebook
            predictions = decode_batch_predictions(y_pred, self.max_label_length)
            self.add_metric(
                calculate_accuracy(
                    y_true, predictions, 'word', self.unknown_placeholder
                ),
                'word_accuracy',
            )
            self.add_metric(
                calculate_accuracy(
                    y_true, predictions, 'char', self.unknown_placeholder
                ),
                'char_accuracy',
            )
        self.add_loss(loss)
        return y_pred
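
For context, here is a rough sketch of how such a layer is typically wired into the model (the input names and layer sizes below are illustrative, not the notebook's actual architecture, although the (batch, 31, 20) prediction shape mirrors the one in the error further down). Because the labels enter the model as a second Input, CTCLayer.call receives a symbolic placeholder whose static shape is (None, None) while the model is being built, so the if-guard above skips the metric code at that point:

import tensorflow as tf
from tensorflow.keras import layers

labels = layers.Input(name='label', shape=(None,), dtype='float32')          # static shape (None, None)
features = layers.Input(name='features', shape=(31, 64), dtype='float32')    # (time steps, feature size)
softmax_out = layers.Dense(20, activation='softmax')(features)               # (batch, 31, 20) class scores
# At build time y_true.shape[1] is None, so the guarded metric block does not run.
output = CTCLayer(max_label_length=5, unknown_placeholder=10)(labels, softmax_out)
model = tf.keras.Model(inputs=[features, labels], outputs=output)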

The if y_true.shape[1] is not None block is meant to prevent an error that occurs while the model is being created, because at that point a placeholder is passed instead of an actual tensor. Here's what happens if the if statement is not present (with or without eager execution, I get the same error):

/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    697       except Exception as e:  # pylint:disable=broad-except
    698         if hasattr(e, 'ag_error_metadata'):
--> 699           raise e.ag_error_metadata.to_exception(e)
    700         else:
    701           raise

ValueError: Exception encountered when calling layer "ctc_loss" (type CTCLayer).

in user code:

    File "<ipython-input-6-fabf4ec5a640>", line 67, in call  *
        predictions = decode_batch_predictions(y_pred, self.max_label_length)
    File "<ipython-input-6-fabf4ec5a640>", line 23, in decode_batch_predictions  *
        results = tf.keras.backend.ctc_decode(
    File "/usr/local/lib/python3.7/dist-packages/keras/backend.py", line 6436, in ctc_decode
        inputs=y_pred, sequence_length=input_length)

    ValueError: Shape must be rank 1 but is rank 0 for '{{node ctc_loss/CTCGreedyDecoder}} = CTCGreedyDecoder[T=DT_FLOAT, blank_index=-1, merge_repeated=true](ctc_loss/Log_1, ctc_loss/Cast_9)' with input shapes: [31,?,20], [].


Call arguments received:
  • args=('tf.Tensor(shape=(None, None), dtype=float32)', 'tf.Tensor(shape=(None, 31, 20), dtype=float32)')

Note: in graph execution, the static shape of the labels is always (None, None), so the code under the if block, which adds the metrics, is never executed. To see the metrics working, just run the notebook I included without modifications, and modify it afterwards to reproduce the error.
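
To illustrate the difference (a standalone sketch, not part of the notebook; the inspect function and its shapes are made up), the static shape the if check looks at is fixed at trace time, while tf.shape is only resolved at run time:

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.float32)])
def inspect(y_true):
    print('static shape during tracing:', y_true.shape)       # (None, None) -> the if block never runs
    tf.print('dynamic shape at run time:', tf.shape(y_true))  # e.g. [16 5]
    return tf.shape(y_true)[1]

inspect(tf.zeros((16, 5)))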

Here's what you should see when eager execution is enabled:

/usr/local/lib/python3.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py:4527: UserWarning: Even though the `tf.config.experimental_run_functions_eagerly` option is set, this option does not apply to tf.data functions. To force eager execution of tf.data functions, please use `tf.data.experimental.enable_debug_mode()`.
  "Even though the `tf.config.experimental_run_functions_eagerly` "
Epoch 1/100
     59/Unknown - 42s 177ms/step - loss: 18.1605 - word_accuracy: 0.0000e+00 - char_accuracy: 2.1186e-04
Epoch 00001: val_loss improved from inf to 17.36043, saving model to 1k_captcha.tf

59/59 [==============================] - 44s 213ms/step - loss: 18.1605 - word_accuracy: 0.0000e+00 - char_accuracy: 2.1186e-04 - val_loss: 17.3604 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0000e+00
Epoch 2/100
59/59 [==============================] - ETA: 0s - loss: 16.1261 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0021
Epoch 00002: val_loss improved from 17.36043 to 16.20875, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 210ms/step - loss: 16.1261 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0021 - val_loss: 16.2087 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0000e+00
Epoch 3/100
59/59 [==============================] - ETA: 0s - loss: 15.8597 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0110
Epoch 00003: val_loss improved from 16.20875 to 16.11712, saving model to 1k_captcha.tf

59/59 [==============================] - 12s 204ms/step - loss: 15.8597 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0110 - val_loss: 16.1171 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0071
Epoch 4/100
59/59 [==============================] - ETA: 0s - loss: 15.3741 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0184
Epoch 00004: val_loss did not improve from 16.11712

59/59 [==============================] - 12s 207ms/step - loss: 15.3741 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0184 - val_loss: 16.6811 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0143
Epoch 5/100
59/59 [==============================] - ETA: 0s - loss: 14.9846 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0225
Epoch 00005: val_loss improved from 16.11712 to 15.23923, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 214ms/step - loss: 14.9846 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0225 - val_loss: 15.2392 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0268
Epoch 6/100
59/59 [==============================] - ETA: 0s - loss: 14.4598 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0258
Epoch 00006: val_loss did not improve from 15.23923

59/59 [==============================] - 12s 207ms/step - loss: 14.4598 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0258 - val_loss: 18.6373 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0089
Epoch 7/100
59/59 [==============================] - ETA: 0s - loss: 13.8650 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0335
Epoch 00007: val_loss improved from 15.23923 to 14.37547, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 215ms/step - loss: 13.8650 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0335 - val_loss: 14.3755 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0393
Epoch 8/100
59/59 [==============================] - ETA: 0s - loss: 13.1221 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0422
Epoch 00008: val_loss did not improve from 14.37547

59/59 [==============================] - 13s 208ms/step - loss: 13.1221 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0422 - val_loss: 14.4376 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0393
Epoch 9/100
59/59 [==============================] - ETA: 0s - loss: 12.2508 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0780
Epoch 00009: val_loss did not improve from 14.37547

59/59 [==============================] - 13s 211ms/step - loss: 12.2508 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0780 - val_loss: 14.8398 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0500
Epoch 10/100
59/59 [==============================] - ETA: 0s - loss: 11.0290 - word_accuracy: 0.0000e+00 - char_accuracy: 0.1460
Epoch 00010: val_loss did not improve from 14.37547

59/59 [==============================] - 13s 215ms/step - loss: 11.0290 - word_accuracy: 0.0000e+00 - char_accuracy: 0.1460 - val_loss: 14.4219 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.1054
Epoch 11/100
59/59 [==============================] - ETA: 0s - loss: 9.8587 - word_accuracy: 0.0011 - char_accuracy: 0.2004
Epoch 00011: val_loss improved from 14.37547 to 10.11944, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 212ms/step - loss: 9.8587 - word_accuracy: 0.0011 - char_accuracy: 0.2004 - val_loss: 10.1194 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.1750
Epoch 12/100
59/59 [==============================] - ETA: 0s - loss: 8.6827 - word_accuracy: 0.0032 - char_accuracy: 0.2388
Epoch 00012: val_loss did not improve from 10.11944

59/59 [==============================] - 13s 216ms/step - loss: 8.6827 - word_accuracy: 0.0032 - char_accuracy: 0.2388 - val_loss: 10.3900 - val_word_accuracy: 0.0089 - val_char_accuracy: 0.1714
Epoch 13/100
59/59 [==============================] - ETA: 0s - loss: 7.4976 - word_accuracy: 0.0127 - char_accuracy: 0.3047
Epoch 00013: val_loss improved from 10.11944 to 8.38430, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 215ms/step - loss: 7.4976 - word_accuracy: 0.0127 - char_accuracy: 0.3047 - val_loss: 8.3843 - val_word_accuracy: 0.0179 - val_char_accuracy: 0.2714
Epoch 14/100
59/59 [==============================] - ETA: 0s - loss: 6.6434 - word_accuracy: 0.0508 - char_accuracy: 0.3519
Epoch 00014: val_loss did not improve from 8.38430

59/59 [==============================] - 13s 217ms/step - loss: 6.6434 - word_accuracy: 0.0508 - char_accuracy: 0.3519 - val_loss: 9.5689 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.2571
Epoch 15/100
59/59 [==============================] - ETA: 0s - loss: 5.3200 - word_accuracy: 0.1398 - char_accuracy: 0.4271
Epoch 00015: val_loss improved from 8.38430 to 6.74445, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 214ms/step - loss: 5.3200 - word_accuracy: 0.1398 - char_accuracy: 0.4271 - val_loss: 6.7445 - val_word_accuracy: 0.0804 - val_char_accuracy: 0.3482
Epoch 16/100
59/59 [==============================] - ETA: 0s - loss: 4.4252 - word_accuracy: 0.2108 - char_accuracy: 0.4799
Epoch 00016: val_loss improved from 6.74445 to 5.40682, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 222ms/step - loss: 4.4252 - word_accuracy: 0.2108 - char_accuracy: 0.4799 - val_loss: 5.4068 - val_word_accuracy: 0.1161 - val_char_accuracy: 0.4446
Epoch 17/100
59/59 [==============================] - ETA: 0s - loss: 3.8119 - word_accuracy: 0.2691 - char_accuracy: 0.5206
Epoch 00017: val_loss improved from 5.40682 to 4.76755, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 220ms/step - loss: 3.8119 - word_accuracy: 0.2691 - char_accuracy: 0.5206 - val_loss: 4.7676 - val_word_accuracy: 0.1964 - val_char_accuracy: 0.4929
Epoch 18/100
59/59 [==============================] - ETA: 0s - loss: 3.1290 - word_accuracy: 0.3379 - char_accuracy: 0.5712
Epoch 00018: val_loss improved from 4.76755 to 4.45828, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 221ms/step - loss: 3.1290 - word_accuracy: 0.3379 - char_accuracy: 0.5712 - val_loss: 4.4583 - val_word_accuracy: 0.2768 - val_char_accuracy: 0.5375
Epoch 19/100
59/59 [==============================] - ETA: 0s - loss: 2.6048 - word_accuracy: 0.4163 - char_accuracy: 0.6267
Epoch 00019: val_loss improved from 4.45828 to 4.13174, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 222ms/step - loss: 2.6048 - word_accuracy: 0.4163 - char_accuracy: 0.6267 - val_loss: 4.1317 - val_word_accuracy: 0.2054 - val_char_accuracy: 0.5143
Epoch 20/100
59/59 [==============================] - ETA: 0s - loss: 2.1555 - word_accuracy: 0.5117 - char_accuracy: 0.6979
Epoch 00020: val_loss improved from 4.13174 to 3.35257, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 223ms/step - loss: 2.1555 - word_accuracy: 0.5117 - char_accuracy: 0.6979 - val_loss: 3.3526 - val_word_accuracy: 0.3482 - val_char_accuracy: 0.5518
Epoch 21/100
59/59 [==============================] - ETA: 0s - loss: 1.8185 - word_accuracy: 0.5604 - char_accuracy: 0.7284
Epoch 00021: val_loss did not improve from 3.35257

59/59 [==============================] - 13s 223ms/step - loss: 1.8185 - word_accuracy: 0.5604 - char_accuracy: 0.7284 - val_loss: 3.5486 - val_word_accuracy: 0.3304 - val_char_accuracy: 0.5500
Epoch 22/100
59/59 [==============================] - ETA: 0s - loss: 1.4279 - word_accuracy: 0.6578 - char_accuracy: 0.8021
Epoch 00022: val_loss improved from 3.35257 to 2.97987, saving model to 1k_captcha.tf

59/59 [==============================] - 14s 229ms/step - loss: 1.4279 - word_accuracy: 0.6578 - char_accuracy: 0.8021 - val_loss: 2.9799 - val_word_accuracy: 0.3750 - val_char_accuracy: 0.6679
Epoch 23/100
59/59 [==============================] - ETA: 0s - loss: 1.1666 - word_accuracy: 0.7278 - char_accuracy: 0.8417
Epoch 00023: val_loss did not improve from 2.97987

59/59 [==============================] - 13s 224ms/step - loss: 1.1666 - word_accuracy: 0.7278 - char_accuracy: 0.8417 - val_loss: 5.2543 - val_word_accuracy: 0.1429 - val_char_accuracy: 0.4768
Epoch 24/100
59/59 [==============================] - ETA: 0s - loss: 1.0938 - word_accuracy: 0.7511 - char_accuracy: 0.8576
Epoch 00024: val_loss improved from 2.97987 to 2.72415, saving model to 1k_captcha.tf

59/59 [==============================] - 14s 226ms/step - loss: 1.0938 - word_accuracy: 0.7511 - char_accuracy: 0.8576 - val_loss: 2.7242 - val_word_accuracy: 0.4911 - val_char_accuracy: 0.7250
Epoch 25/100
59/59 [==============================] - ETA: 0s - loss: 0.8378 - word_accuracy: 0.7977 - char_accuracy: 0.8837
Epoch 00025: val_loss improved from 2.72415 to 2.47315, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 223ms/step - loss: 0.8378 - word_accuracy: 0.7977 - char_accuracy: 0.8837 - val_loss: 2.4731 - val_word_accuracy: 0.4554 - val_char_accuracy: 0.6964
Epoch 26/100
59/59 [==============================] - ETA: 0s - loss: 0.6497 - word_accuracy: 0.8633 - char_accuracy: 0.9195
Epoch 00026: val_loss improved from 2.47315 to 2.10521, saving model to 1k_captcha.tf

59/59 [==============================] - 14s 227ms/step - loss: 0.6497 - word_accuracy: 0.8633 - char_accuracy: 0.9195 - val_loss: 2.1052 - val_word_accuracy: 0.4821 - val_char_accuracy: 0.6929
Epoch 27/100
59/59 [==============================] - ETA: 0s - loss: 0.4810 - word_accuracy: 0.9153 - char_accuracy: 0.9528
Epoch 00027: val_loss did not improve from 2.10521

59/59 [==============================] - 14s 226ms/step - loss: 0.4810 - word_accuracy: 0.9153 - char_accuracy: 0.9528 - val_loss: 2.5292 - val_word_accuracy: 0.4375 - val_char_accuracy: 0.7054
Epoch 28/100
59/59 [==============================] - ETA: 0s - loss: 0.4621 - word_accuracy: 0.9121 - char_accuracy: 0.9500
Epoch 00028: val_loss did not improve from 2.10521

59/59 [==============================] - 14s 224ms/step - loss: 0.4621 - word_accuracy: 0.9121 - char_accuracy: 0.9500 - val_loss: 2.1713 - val_word_accuracy: 0.4821 - val_char_accuracy: 0.7268

To reproduce the problem: if you have already run the notebook, you may need to restart the runtime; then run it without eager execution and the metrics never show up. To reproduce the error instead, comment out the if y_true.shape[1] is not None line and dedent its body so it always runs. What do I need to modify in the provided notebook to make the metrics work as demonstrated, without having to use eager execution?


Solution

  • You may not like this kind of solution, but you could try changing your calculate_accuracy and decode_batch_predictions functions so that they only use tf operations:

    def decode_batch_predictions(predictions, max_label_length, char_lookup=None, increment=0):
        # Graph-safe: sequence lengths are built from tf.shape (dynamic) instead of
        # the static .shape attribute.
        input_length = tf.cast(tf.ones(tf.shape(predictions)[0]), dtype=tf.int32) * tf.cast(
            tf.shape(predictions)[1], dtype=tf.int32
        )

        # Greedy CTC decoding, truncated to the maximum label length.
        results = tf.keras.backend.ctc_decode(
            predictions, input_length=input_length, greedy=True
        )[0][0][:, :max_label_length] + increment

        if char_lookup:  # For inference: map the decoded indices back to characters.
            output = []
            for result in results:
                result = tf.strings.reduce_join(char_lookup(result)).numpy().decode('utf-8')
                output.append(result)
            return output
        else:  # For training: return the decoded indices as a single tensor.
            output = tf.TensorArray(tf.int64, size=0, dynamic_size=True)
            for result in results:
                output = output.write(output.size(), result)
            return output.stack()
    
    def calculate_accuracy(y_true, y_pred, metric, unknown_placeholder):
        y_pred = tf.stack(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        # Replace the -1 padding emitted by ctc_decode with a placeholder value
        # that can never match a real label.
        unknown_indices = tf.where(y_pred == -1)
        y_pred = tf.tensor_scatter_nd_update(
            y_pred,
            unknown_indices,
            tf.cast(tf.ones(tf.shape(unknown_indices)[0]) * unknown_placeholder, tf.int64),
        )
        # tf.shape works on symbolic tensors, so both metrics can run in graph mode.
        if metric == 'word':
            return tf.shape(tf.where(tf.reduce_all(y_true == y_pred, 1)))[0] / tf.shape(y_true)[0]
        if metric == 'char':
            return tf.shape(tf.where(y_true == y_pred))[0] / tf.reduce_prod(tf.shape(y_true))
        return 0
    
    Writing example: 936/1040 [90.0 %] to e7fe398b-da12-4176-a91c-84a8ca076937-train.tfrecord
    Writing example: 1040/1040 [100.0 %] to e7fe398b-da12-4176-a91c-84a8ca076937-valid.tfrecord
    Epoch 1/100
         59/Unknown - 107s 470ms/step - loss: 18.2176 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0015
    Epoch 00001: val_loss improved from inf to 16.23781, saving model to 1k_captcha.tf
    

    (screenshot of the training output)

    This way you do not have to use tf.config.run_functions_eagerly(True) or the if y_true.shape[1] is not None check.
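
    As a quick sanity check, here is a small self-contained sketch (the check wrapper and the dummy tensors are made up for illustration) that traces the rewritten calculate_accuracy inside a tf.function, i.e. in graph mode, without eager execution enabled:

    import tensorflow as tf

    @tf.function  # graph mode: static shapes may be unknown, but tf.shape still works
    def check(y_true, y_pred):
        return (
            calculate_accuracy(y_true, y_pred, 'word', unknown_placeholder=10),
            calculate_accuracy(y_true, y_pred, 'char', unknown_placeholder=10),
        )

    labels = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64)
    decoded = tf.constant([[1, 2, 3], [4, 5, -1]], dtype=tf.int64)  # -1 marks an undecoded position
    print(check(labels, decoded))  # expect word accuracy 0.5 and char accuracy 5/6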