python tensorflow keras loss-function

In Keras, how can I save and load a neural network model that includes a custom loss function?


I am having difficulty saving and reloading a neural network model that uses a custom loss function. In the code below (which integrates the suggestions from two related questions), "Save/Load Attempt 0" works without errors, while "Save/Load Attempt 1" fails with the cryptic error TypeError: string indices must be integers, not 'str'. The error occurs whether the model is loaded with compile=False or with custom_objects={'loss': custom_loss}. How can I modify "Save/Load Attempt 1" so that it succeeds?

import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
from keras.models import Sequential, load_model
from keras.layers import Input, Dense
from keras import ops

path = 'C:/Users/.../AppData/Local/Programs/Python/Python312/.../'  # The ... represent single folders.

# ----------------------------------------------------------------------------------------------------
# Save/Load Attempt 0

dnn = Sequential()
dnn.add(Input(shape=(3,)))
dnn.add(Dense(units=5, activation='relu'))
dnn.add(Dense(units=1))
dnn.compile(loss='mean_absolute_error', optimizer='adam')

model_path_0 = path + 'dnn_0.h5'
dnn.save(model_path_0)

dnn = load_model(model_path_0)

# ----------------------------------------------------------------------------------------------------
print('---')
# Save/Load Attempt 1

def custom_loss(y_true, y_pred):
    squared_difference = ops.square(y_true - y_pred)
    return ops.mean(squared_difference, axis=-1)  # averages over the last axis

dnn = Sequential()
dnn.add(Input(shape=(3,)))
dnn.add(Dense(units=5, activation='relu'))
dnn.add(Dense(units=1))
dnn.compile(loss=custom_loss, optimizer='adam')

model_path_1 = path + 'dnn_1.h5'
dnn.save(model_path_1)

dnn = load_model(model_path_1, custom_objects={'loss': custom_loss})
# dnn = load_model(model_path_1, compile=False)

# ----------------------------------------------------------------------------------------------------

For reference, the traceback of the error is as follows.

Traceback (most recent call last):
  File "c:\Users\...\AppData\Local\Programs\Python\Python312\...\test_load_model.py", line 43, in <module>
    dnn = load_model(model_path_1, compile=False)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\...\AppData\Local\Programs\Python\Python312\Lib\site-packages\keras\src\saving\saving_api.py", line 183, in load_model
    return legacy_h5_format.load_model_from_hdf5(filepath)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\...\AppData\Local\Programs\Python\Python312\Lib\site-packages\keras\src\legacy\saving\legacy_h5_format.py", line 155, in load_model_from_hdf5
    **saving_utils.compile_args_from_training_config(
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\...\AppData\Local\Programs\Python\Python312\Lib\site-packages\keras\src\legacy\saving\saving_utils.py", line 145, in compile_args_from_training_config
    loss = _resolve_compile_arguments_compat(loss, loss_config, losses)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\...\AppData\Local\Programs\Python\Python312\Lib\site-packages\keras\src\legacy\saving\saving_utils.py", line 245, in _resolve_compile_arguments_compat
    obj = module.get(obj_config["config"]["name"])
                     ~~~~~~~~~~^^^^^^^^^^
TypeError: string indices must be integers, not 'str'

Solution. Following stateMachine's answer, I resolved the error with the following three changes:

  1. I uninstalled Keras 3.1.1 (pip uninstall keras) and installed 3.5.0 (pip install keras).
  2. I changed the extensions I was using for saving and loading from .h5 to the newer .keras.
  3. I set compile=False in the load_model command and instead compiled the model with a dnn.compile command after loading it.

Solution

  • Are you compiling the model after reloading it? You don't mention which version of Keras you are using, but here is a toy example that I tested with TensorFlow and Keras 2.14.0 as well as with TensorFlow 2.17.0 and Keras 3.5.0:

    import numpy as np
    import tensorflow as tf
    
    from keras import Sequential
    from keras.layers import Input, Dense
    
    # Model architecture:
    model = Sequential()
    model.add(Input(shape=(100,)))
    model.add(Dense(units=5, activation='relu'))
    model.add(Dense(units=1))
    
    # Custom Loss
    def custom_loss(y_true, y_pred):
        squared_difference = tf.math.square(y_true - y_pred)
        return tf.reduce_mean(squared_difference, axis=-1)  
    
    # Compile model:
    model.compile(optimizer="adam", loss=custom_loss, metrics=["mean_squared_error"])
    
    # Show summary:
    model.summary()
    
    # Some random inputs/targets (one target column, matching the single output unit):
    x = np.random.rand(300, 100)
    y = np.random.rand(300, 1)
    
    # Fit the model for 5 epochs:
    model.fit(x,y,batch_size=100, epochs=5)
    
    # Save model
    path = 'saved_model/myModel.keras'
    model.save(path)
    print("Model saved")
    
    # Load model:
    del model
    model = tf.keras.models.load_model(path, compile=False, custom_objects={"custom_loss": custom_loss})
    print("Model reloaded")
    
    # Compile:
    model.compile(optimizer="adam", loss=custom_loss, metrics=["mean_squared_error"])
    
    # Continue training for 5 more epochs:
    model.fit(x, y, batch_size=100, epochs=5)
    print("Done fitting")
    

    Output:

    Model: "sequential_6"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     dense_12 (Dense)            (None, 5)                 505       
                                                                     
     dense_13 (Dense)            (None, 1)                 6         
                                                                     
    =================================================================
    Total params: 511 (2.00 KB)
    Trainable params: 511 (2.00 KB)
    Non-trainable params: 0 (0.00 Byte)
    _________________________________________________________________
    Epoch 1/5
    3/3 [==============================] - 1s 8ms/step - loss: 0.5925 - mean_squared_error: 0.5925
    Epoch 2/5
    3/3 [==============================] - 0s 10ms/step - loss: 0.4522 - mean_squared_error: 0.4522
    Epoch 3/5
    3/3 [==============================] - 0s 8ms/step - loss: 0.3607 - mean_squared_error: 0.3607
    Epoch 4/5
    3/3 [==============================] - 0s 10ms/step - loss: 0.2928 - mean_squared_error: 0.2928
    Epoch 5/5
    3/3 [==============================] - 0s 15ms/step - loss: 0.2372 - mean_squared_error: 0.2372
    Model saved
    Model reloaded
    Epoch 1/5
    3/3 [==============================] - 1s 6ms/step - loss: 0.2142 - mean_squared_error: 0.2142
    Epoch 2/5
    3/3 [==============================] - 0s 4ms/step - loss: 0.1967 - mean_squared_error: 0.1967
    Epoch 3/5
    3/3 [==============================] - 0s 4ms/step - loss: 0.1827 - mean_squared_error: 0.1827
    Epoch 4/5
    3/3 [==============================] - 0s 5ms/step - loss: 0.1659 - mean_squared_error: 0.1659
    Epoch 5/5
    3/3 [==============================] - 0s 4ms/step - loss: 0.1528 - mean_squared_error: 0.1528
    Done fitting
    

    Some notes: The recommended format for saving models is the .keras extension. I assumed here that this is a regression problem, but it doesn't really matter: the data is random and we are only testing whether the model is saved and reloaded properly. We just need to check that the metrics indeed pick up where training first left off.

    When dealing with custom objects, I personally prefer subclassing and explicitly including the get_config() method. The config it returns is a serializable Python dictionary containing the information needed to re-instantiate the object. This generally makes things tidier and easier to maintain. See https://keras.io/guides/serialization_and_saving/, specifically the "Custom objects" section.