python tensorflow keras resuming-training

Keras - manage history


I am training Keras models, saving them with model.save(), and then later loading them and resuming training.

After each training session I would like to plot the whole training history, but model.fit_generator() only returns the history of the most recent session.

I can save the history of the initial session and update it myself, but I wonder if there is a standard way in Keras of managing the training history.

history1 = model.fit_generator(my_gen)
plot_history(history1)
model.save('my_model.h5')

# Some days afterwards...

model = load_model('my_model.h5')
history2 = model.fit_generator(my_gen)

# here I would like to reconstruct the full_training history
# including the info from history1 and history2
full_history = ???

Solution

  • Use numpy to concatenate the history values for the specific keys that you are interested in.

    For example, let's say these are your two training runs:

    history1 = model.fit_generator(my_gen)
    history2 = model.fit_generator(my_gen)
    

    You can view the dictionary keys, which are the same for each run as long as the model is compiled with the same metrics, with:

    print(history1.history.keys())
    

    This will print output like:

    dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
    

    Then you can use numpy concatenate. For example:

    import numpy as np
    # Join the per-epoch values of both runs, in training order
    combo_accuracy = np.concatenate((history1.history['accuracy'], history2.history['accuracy']), axis=0)
    combo_val_accuracy = np.concatenate((history1.history['val_accuracy'], history2.history['val_accuracy']), axis=0)
    

    You can then plot the concatenated history arrays with matplotlib:

    import matplotlib.pyplot as plt
    plt.plot(combo_accuracy, 'orange', label='Training accuracy')
    plt.plot(combo_val_accuracy, 'blue', label='Validation accuracy')
    plt.legend()
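
The concatenation above assumes history1 and history2 are both still in memory. In the scenario from the question, training resumes days later in a new process, so the earlier history would have to be stored somewhere first. Below is a minimal sketch of one way to do that with a JSON file kept next to the model. The helper names merge_histories and update_history_file and the file name my_model_history.json are hypothetical, not part of Keras, and the plain dicts they take stand in for the history attribute of the object returned by training.

```python
import json
import os


def merge_histories(*histories):
    """Concatenate the per-epoch lists of several history dicts, in order."""
    merged = {}
    for history in histories:
        for key, values in history.items():
            # float() converts NumPy scalars to plain floats so that the
            # merged dict can be written out as JSON.
            merged.setdefault(key, []).extend(float(v) for v in values)
    return merged


def update_history_file(path, new_history):
    """Append new_history to the history stored at path; return the full history."""
    previous = {}
    if os.path.exists(path):
        with open(path) as f:
            previous = json.load(f)
    full = merge_histories(previous, new_history)
    with open(path, "w") as f:
        json.dump(full, f)
    return full


# Hypothetical usage, once per training session:
#   history = model.fit_generator(my_gen)
#   full_history = update_history_file('my_model_history.json', history.history)
#   model.save('my_model.h5')
```

This sketch assumes every session records the same metrics. If a key only appears in later sessions, its list will be shorter than the others and the curves will no longer line up by epoch. Keras also ships a CSVLogger callback whose append option adds rows to an existing log file instead of overwriting it, which may cover the same need without custom code.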