openmdao

Loading KrigingSurrogate Trained Model without Storing Training Data


OpenMDAO's KrigingSurrogate lets the user cache a trained Kriging surrogate model and load it later through the optional argument training_cache. This works well except for one sometimes inconvenient behavior: KrigingSurrogate always checks the training data saved with the trained model against the training data it is given, and loads the trained model only if the two match. Otherwise, the model is retrained with the new training data. Unfortunately, this seems to require the user to pickle the training data (both inputs and outputs) separately if they want to train the model in one script and load it in another.
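To make the behavior concrete, here is a minimal sketch of a hash-based cache check of the kind described above. It is an illustration only, not OpenMDAO's code: it assumes the cache is an `.npz` archive holding a hash of the training data next to the trained parameters, and `fingerprint`, `load_or_train`, and the `train` callback are hypothetical names.

```python
import hashlib
import os

import numpy as np


def fingerprint(x, y):
    """Return a hash of the training inputs and outputs (shape, dtype, and values)."""
    h = hashlib.sha256()
    for arr in (x, y):
        a = np.ascontiguousarray(arr)
        h.update(str(a.shape).encode())
        h.update(str(a.dtype).encode())
        h.update(a.tobytes())
    return h.hexdigest()


def load_or_train(cache_path, x, y, train):
    """Reuse cached parameters only if they were trained on exactly (x, y).

    Returns (params, retrained), where params is a dict of NumPy arrays.
    """
    key = fingerprint(x, y)
    if os.path.exists(cache_path):
        with np.load(cache_path) as data:
            if str(data['hash']) == key:
                # Cache matches the provided training data: skip training.
                return {k: data[k] for k in data.files if k != 'hash'}, False
    # No cache, or the training data changed: retrain and overwrite the cache.
    params = train(x, y)
    # Passing an open file keeps np.savez from appending '.npz' to the name.
    with open(cache_path, 'wb') as f:
        np.savez(f, hash=key, **params)
    return params, True
```

The detail relevant to the question is the signature: the check takes `x` and `y` as inputs, because the cache alone cannot say whether the data it was built from is still the data the caller intends to use.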

Is there any way to skip the training data validation and instead use the training data that is already saved in the trained model?

My current method for creating a Kriging model in one script and then loading it in another looks like this:

# create_model.py

import numpy as np
import openmdao.api as om
import pickle

x = np.arange(0, 11, 1)
y = x**2

surrogate = om.MetaModelUnStructuredComp()
surrogate.add_input('x', training_data=x)
surrogate.add_output('y', training_data=y, surrogate=om.KrigingSurrogate(training_cache='surrogate.dat'))

prob = om.Problem()
prob.model.add_subsystem('surrogate', surrogate)
prob.setup()
prob.run_model() # trains model, saves to surrogate.dat

training_data = {
    'x': x,
    'y': y
}

# pickle training data
with open('training_data.pickle', 'wb') as training_data_file:
    pickle.dump(training_data, training_data_file)

# load_model.py

import numpy as np
import openmdao.api as om
import pickle

# I want to skip this because the training data is already saved with the model
# I am about to load, but I can't because KrigingSurrogate requires training data
# to check the saved model against.
with open('training_data.pickle', 'rb') as training_data_file:
    training_data = pickle.load(training_data_file)

x = training_data['x']
y = training_data['y']

surrogate = om.MetaModelUnStructuredComp()
surrogate.add_input('x', training_data=x)
surrogate.add_output('y', training_data=y, surrogate=om.KrigingSurrogate(training_cache='surrogate.dat'))

prob = om.Problem()
prob.model.add_subsystem('surrogate', surrogate)
prob.setup()
prob.run_model() # loads trained model
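As a side note on the workaround itself: if a separate training-data file is needed anyway, NumPy's `np.savez` and `np.load` can store plain numeric arrays without pickle (`np.load` refuses to unpickle objects by default). A sketch, assuming the same `x` and `y` arrays as above; the file name `training_data.npz` is arbitrary.

```python
import numpy as np

# create_model.py: save the training arrays next to the surrogate cache
x = np.arange(0, 11, 1)
y = x**2
np.savez('training_data.npz', x=x, y=y)

# load_model.py: read them back before building the MetaModelUnStructuredComp
with np.load('training_data.npz') as training_data:
    x = training_data['x']
    y = training_data['y']
```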

Solution

  • As of OpenMDAO V3.25, KrigingSurrogate still validates its cache against the provided training data. In principle, it would be a nice improvement for the surrogate to store that training data in its cache file and reload it from the same file, which would save you the extra step of pickling it.

    The problem is that the surrogate would have no good way to know whether a change in the provided training data was intentional. A user might set new training data and then be surprised when it gets overwritten by the data in the cache. You could issue a warning, but I've found that many users ignore warnings :(

    If you want to customize the behavior, you can write your own version of the KrigingSurrogate class with that validation removed. It only takes deleting a few lines of code.

    If you can think of a clean way to decide whether the cache is valid and to reload the training inputs from it (without overwriting any user-provided inputs), feel free to submit a POEM (Proposal for OpenMDAO Enhancement). Otherwise, just make your own surrogate, comment out the cache validity check, and be careful!
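For anyone considering such a proposal, here is a rough, library-independent sketch of one possible policy, assuming the cache file stores the raw training data next to the trained parameters. Explicitly passed training data always wins (a mismatched cache is discarded with a warning), and the cached data is reused only when no data is passed. `save_cache` and `resolve_training_data` are hypothetical names, not part of OpenMDAO.

```python
import os
import warnings

import numpy as np


def save_cache(path, x, y, params):
    """Store the training data together with the trained parameters in one file."""
    # Passing an open file keeps np.savez from appending '.npz' to the name.
    with open(path, 'wb') as f:
        np.savez(f, train_x=x, train_y=y, **params)


def resolve_training_data(path, x=None, y=None):
    """Decide which training data to use and whether the cached parameters are valid.

    Returns (x, y, params); params is None when the model must be retrained.
    """
    if not os.path.exists(path):
        if x is None or y is None:
            raise ValueError('no training data given and no cache found')
        return x, y, None

    with np.load(path) as data:
        cached_x, cached_y = data['train_x'], data['train_y']
        params = {k: data[k] for k in data.files if k not in ('train_x', 'train_y')}

    if x is None and y is None:
        # Nothing provided: reuse the training data stored in the cache.
        return cached_x, cached_y, params
    if x is None or y is None:
        raise ValueError('provide both x and y, or neither')
    if np.array_equal(x, cached_x) and np.array_equal(y, cached_y):
        # Provided data matches the cache: the cached parameters are valid.
        return x, y, params
    # Provided data differs: keep the user's data and retrain instead.
    warnings.warn('training data differs from the cache; retraining')
    return x, y, None
```

The last branch is where the difficulty described above lives: the code cannot tell an intentional data change from an accidental one, so this sketch errs on the side of the user's data rather than the cache.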