How to remove duplicate matplotlib plot


I am trying to develop a small feature that dynamically plots the loss or accuracy while a TensorFlow model trains. Essentially, I plot the history of accuracies at the end of each batch, across all epochs (the code still needs some corrections, but it works correctly for now).

I have a small problem when I run the following code in a Jupyter notebook cell. I get the desired behavior, a plot that updates dynamically. However, at the end of training the final plot is duplicated, and I can't figure out why.

from IPython.display import display, clear_output
import tensorflow as tf
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt


class CustomCallback(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.epoch = 0  # Initialize the epoch counter
        self.accuracies = []
        self.fig, self.ax = plt.subplots()
        self.line, = self.ax.plot([], [])
        self.ax.set_xlim(0, 30)
        self.ax.set_ylim(0, 1)
        self.displayed = False
        display(self.fig)
    
    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch  # Update the current epoch at the beginning of each epoch

    def on_train_batch_end(self, batch, logs=None):
        accuracy = logs['accuracy']
        self.accuracies.append(accuracy)
        self.line.set_data(range(1, len(self.accuracies) + 1), self.accuracies)
        self.ax.relim()
        self.ax.autoscale_view()
        clear_output(wait=True)
        display(self.fig)


custom_callback = CustomCallback()

model = Sequential()
model.add(tf.keras.layers.Dense(units=16, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=0.35))
model.add(tf.keras.layers.Dense(units=1, activation='tanh'))

model.compile(optimizer=tf.keras.optimizers.Adam(), loss="binary_crossentropy", metrics=["accuracy"])

X = np.random.randn(10**2, 10**4)
y = np.random.randint(2, size=10**2)

abc = model.fit(X, y, epochs=7, batch_size=32, validation_split=0.025, verbose=False, callbacks=[custom_callback])

Solution

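  • This happens because Jupyter Notebook, with its default inline backend, automatically renders any figure that is still open when a cell finishes running, so a figure you have already shown with display() appears a second time. For example, the following code shows the same line plot twice in a Jupyter notebook.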

    from IPython.display import display
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.plot(range(3))
    display(fig)  # shown once here, then again when the cell finishes
    

    To stop the automatic rendering, turn off interactive mode by calling plt.ioff() right after the matplotlib import. Alternatively, close the figure at the end of training by adding the following method to the class.

        def on_train_end(self, logs=None):
            plt.close(self.fig)
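
To see why closing the figure helps, here is a minimal sketch, outside of Keras, of what plt.close() does to pyplot's list of open figures. It assumes the inline backend only re-renders figures that are still registered with pyplot when the cell ends; the variable names are just for illustration.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(range(3))
fig_id = fig.number  # pyplot's identifier for this figure

# While the figure is open, pyplot keeps it in its registry. This is the
# registry the inline backend is assumed to consult at the end of a cell.
open_before = plt.get_fignums()

# Closing the figure removes it from the registry, so nothing is left
# to be rendered automatically afterwards.
plt.close(fig)
open_after = plt.get_fignums()

print(fig_id in open_before, fig_id in open_after)
```

In the callback, the last display(self.fig) call in on_train_batch_end has already shown the final state, so closing the figure in on_train_end should leave only that one copy in the output.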