python, tensorflow, keras, callback, image-classification

Evaluate model on Testing Set after each epoch of training


I'm training a TensorFlow model on an image dataset for a classification task. Usually we pass the training set and the validation set to model.fit, and afterwards we can plot the model's convergence curves for training and validation. I want to do the same with the test set: I want the model's accuracy and loss on the test set after each epoch. I can't simply replace the validation set with the test set, because I need graphs for both.

I managed to do this by saving a checkpoint of the model after each epoch with a callback, then loading each checkpoint into the model and computing accuracy and loss. Is there an easier way, perhaps another callback or some workaround with model.fit?
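For reference, the checkpoint workaround described above could be sketched roughly as follows. This is an illustration under assumptions, not the asker's actual code: it uses Keras's built-in ModelCheckpoint with save_weights_only=True and an epoch-numbered filename, and the helper name fit_then_evaluate_checkpoints is hypothetical.

```python
import os

import tensorflow as tf


def fit_then_evaluate_checkpoints(model, train_ds, test_ds, epochs, ckpt_dir):
    """Hypothetical helper: save weights after every epoch, then re-evaluate each one."""
    os.makedirs(ckpt_dir, exist_ok=True)
    # ModelCheckpoint fills in '{epoch:02d}'. The '.weights.h5' suffix is what
    # Keras 3 requires when save_weights_only=True (older Keras accepts it too).
    pattern = os.path.join(ckpt_dir, "epoch_{epoch:02d}.weights.h5")
    checkpoint = tf.keras.callbacks.ModelCheckpoint(pattern, save_weights_only=True)
    history = model.fit(train_ds, epochs=epochs, callbacks=[checkpoint], verbose=0)

    # Second pass: reload each saved epoch and evaluate it on the test set.
    test_results = []
    for epoch in range(1, epochs + 1):  # ModelCheckpoint numbers epochs from 1
        model.load_weights(pattern.format(epoch=epoch))
        test_results.append(model.evaluate(test_ds, verbose=0, return_dict=True))
    return history, test_results
```

It works, but it writes one weights file per epoch and needs a second evaluation pass after training; the callback approach in the answer avoids both.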


Solution

  • You could use a custom Callback, pass it your test data, and compute whatever you like:

    import tensorflow as tf
    import pathlib
    
    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
    data_dir = pathlib.Path(data_dir)
    
    batch_size = 5
    
    train_ds = tf.keras.utils.image_dataset_from_directory(
      data_dir,
      seed=123,
      image_size=(64, 64),
      batch_size=batch_size)
    
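    # For demonstration only: these 30 batches come from the training data.
    # In practice, use a separate held-out test set.
    test_ds = train_ds.take(30)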
    
    model = tf.keras.Sequential([
      tf.keras.layers.Rescaling(1./255, input_shape=(64, 64, 3)),
      tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Dropout(0.2),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(5)
    ])
    
    class TestCallback(tf.keras.callbacks.Callback):
        """Computes loss and accuracy on a test dataset at the end of every epoch."""

        def __init__(self, test_dataset):
            super().__init__()
            self.test_dataset = test_dataset
            self.test_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
            # A loss *metric* averages over individual examples, so the result is
            # exact even when the last batch is smaller than the others.
            self.test_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)

        def on_epoch_end(self, epoch, logs=None):
            for x_batch_test, y_batch_test in self.test_dataset:
                test_logits = self.model(x_batch_test, training=False)
                self.test_loss_metric.update_state(y_batch_test, test_logits)
                self.test_acc_metric.update_state(y_batch_test, test_logits)
            # Entries added to logs appear in the progress bar and in history.history;
            # float() stores plain numbers that are easy to plot.
            logs['test_loss'] = float(self.test_loss_metric.result())
            logs['test_sparse_categorical_accuracy'] = float(self.test_acc_metric.result())
            # Reset so the next epoch starts from a clean state.
            self.test_loss_metric.reset_state()
            self.test_acc_metric.reset_state()
    
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer='adam',
                  loss=loss_fn,
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])  # a list, as newer Keras requires
    epochs = 5
    history = model.fit(train_ds, epochs=epochs, callbacks=[TestCallback(test_ds)])
    
    Found 3670 files belonging to 5 classes.
    Epoch 1/5
    734/734 [==============================] - 14s 17ms/step - loss: 1.2709 - sparse_categorical_accuracy: 0.4591 - test_loss: 1.0020 - test_sparse_categorical_accuracy: 0.5533
    Epoch 2/5
    734/734 [==============================] - 13s 18ms/step - loss: 0.9574 - sparse_categorical_accuracy: 0.6275 - test_loss: 0.8348 - test_sparse_categorical_accuracy: 0.6467
    Epoch 3/5
    734/734 [==============================] - 9s 12ms/step - loss: 0.8136 - sparse_categorical_accuracy: 0.6733 - test_loss: 0.8379 - test_sparse_categorical_accuracy: 0.6467
    Epoch 4/5
    734/734 [==============================] - 8s 11ms/step - loss: 0.6970 - sparse_categorical_accuracy: 0.7357 - test_loss: 0.5713 - test_sparse_categorical_accuracy: 0.7533
    Epoch 5/5
    734/734 [==============================] - 8s 11ms/step - loss: 0.5793 - sparse_categorical_accuracy: 0.7834 - test_loss: 0.5656 - test_sparse_categorical_accuracy: 0.7733
    

    You can also simply call model.evaluate inside the callback. See also this post.
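A minimal sketch of that simpler variant, assuming the model has already been compiled with the loss and metrics you want and that the test data is a tf.data dataset (the class name EvaluateOnTestSet is hypothetical). After each epoch it calls model.evaluate with return_dict=True and copies every result into logs under a test_ prefix, so the values appear in the progress bar and in history.history.

```python
import tensorflow as tf


class EvaluateOnTestSet(tf.keras.callbacks.Callback):
    """Hypothetical callback: evaluate on a held-out test set after each epoch."""

    def __init__(self, test_dataset):
        super().__init__()
        self.test_dataset = test_dataset

    def on_epoch_end(self, epoch, logs=None):
        # Runs the model's compiled loss and metrics over the whole test set.
        results = self.model.evaluate(self.test_dataset, verbose=0, return_dict=True)
        if logs is not None:
            # Keys such as 'loss' and 'accuracy' become 'test_loss', 'test_accuracy'.
            for name, value in results.items():
                logs["test_" + name] = float(value)
```

Used as model.fit(train_ds, validation_data=val_ds, epochs=epochs, callbacks=[EvaluateOnTestSet(test_ds)]), where val_ds stands for your validation dataset, history.history should then hold loss, val_loss and test_loss side by side, which is enough to plot all three curves. Because evaluate reuses the compiled loss and metrics, the test numbers are computed exactly like the training and validation ones.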