I am using the Vision Transformer (ViT) image-classification example from the Keras website. The only change I made is a line that saves the model as a ".keras" file once training is done.
Later I try to load the saved model and check its configuration using "get_config()":
Lmodel=load_model("VITexp.keras")  # raises the TypeError below: Patches.__init__() rejects the `name` kwarg from the saved config
Lmodel.get_config()  # never reached while loading fails
But the code fails to load the model and gives me the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/keras/src/ops/operation.py:208, in Operation.from_config(cls, config)
207 try:
--> 208 return cls(**config)
209 except Exception as e:
TypeError: Patches.__init__() got an unexpected keyword argument 'name'
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:718, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
717 try:
--> 718 instance = cls.from_config(inner_config)
719 except TypeError as e:
File /opt/conda/lib/python3.10/site-packages/keras/src/ops/operation.py:210, in Operation.from_config(cls, config)
209 except Exception as e:
--> 210 raise TypeError(
211 f"Error when deserializing class '{cls.__name__}' using "
212 f"config={config}.\n\nException encountered: {e}"
213 )
TypeError: Error when deserializing class 'Patches' using config={'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}.
Exception encountered: Patches.__init__() got an unexpected keyword argument 'name'
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:718, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
717 try:
--> 718 instance = cls.from_config(inner_config)
719 except TypeError as e:
File /opt/conda/lib/python3.10/site-packages/keras/src/models/model.py:517, in Model.from_config(cls, config, custom_objects)
515 from keras.src.models.functional import functional_from_config
--> 517 return functional_from_config(
518 cls, config, custom_objects=custom_objects
519 )
521 # Either the model has a custom __init__, or the config
522 # does not contain all the information necessary to
523 # revive a Functional model. This happens when the user creates
(...)
526 # In this case, we fall back to provide all config into the
527 # constructor of the class.
File /opt/conda/lib/python3.10/site-packages/keras/src/models/functional.py:517, in functional_from_config(cls, config, custom_objects)
516 for layer_data in config["layers"]:
--> 517 process_layer(layer_data)
519 # Then we process nodes in order of layer depth.
520 # Nodes that cannot yet be processed (if the inbound node
521 # does not yet exist) are re-enqueued, and the process
522 # is repeated until all nodes are processed.
File /opt/conda/lib/python3.10/site-packages/keras/src/models/functional.py:501, in functional_from_config.<locals>.process_layer(layer_data)
500 else:
--> 501 layer = serialization_lib.deserialize_keras_object(
502 layer_data, custom_objects=custom_objects
503 )
504 created_layers[layer_name] = layer
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:720, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
719 except TypeError as e:
--> 720 raise TypeError(
721 f"{cls} could not be deserialized properly. Please"
722 " ensure that components that are Python object"
723 " instances (layers, models, etc.) returned by"
724 " `get_config()` are explicitly deserialized in the"
725 " model's `from_config()` method."
726 f"\n\nconfig={config}.\n\nException encountered: {e}"
727 )
728 build_config = config.get("build_config", None)
TypeError: <class '__main__.Patches'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.
config={'module': None, 'class_name': 'Patches', 'config': {'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}, 'registered_name': 'Custom>Patches', 'build_config': {'input_shape': [None, 72, 72, 3]}, 'name': 'patches_1', 'inbound_nodes': [{'args': [{'class_name': '__keras_tensor__', 'config': {'shape': [None, 72, 72, 3], 'dtype': 'float32', 'keras_history': ['data_augmentation', 0, 0]}}], 'kwargs': {}}]}.
Exception encountered: Error when deserializing class 'Patches' using config={'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}.
Exception encountered: Patches.__init__() got an unexpected keyword argument 'name'
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
Cell In[11], line 1
----> 1 Lmodel=load_model("VITexp.keras")
2 Lmodel.get_config()
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/saving_api.py:176, in load_model(filepath, custom_objects, compile, safe_mode)
173 is_keras_zip = True
175 if is_keras_zip:
--> 176 return saving_lib.load_model(
177 filepath,
178 custom_objects=custom_objects,
179 compile=compile,
180 safe_mode=safe_mode,
181 )
182 if str(filepath).endswith((".h5", ".hdf5")):
183 return legacy_h5_format.load_model_from_hdf5(
184 filepath, custom_objects=custom_objects, compile=compile
185 )
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/saving_lib.py:152, in load_model(filepath, custom_objects, compile, safe_mode)
147 raise ValueError(
148 "Invalid filename: expected a `.keras` extension. "
149 f"Received: filepath={filepath}"
150 )
151 with open(filepath, "rb") as f:
--> 152 return _load_model_from_fileobj(
153 f, custom_objects, compile, safe_mode
154 )
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/saving_lib.py:170, in _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode)
168 # Construct the model from the configuration file in the archive.
169 with ObjectSharingScope():
--> 170 model = deserialize_keras_object(
171 config_dict, custom_objects, safe_mode=safe_mode
172 )
174 all_filenames = zf.namelist()
175 if _VARS_FNAME + ".h5" in all_filenames:
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:720, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
718 instance = cls.from_config(inner_config)
719 except TypeError as e:
--> 720 raise TypeError(
721 f"{cls} could not be deserialized properly. Please"
722 " ensure that components that are Python object"
723 " instances (layers, models, etc.) returned by"
724 " `get_config()` are explicitly deserialized in the"
725 " model's `from_config()` method."
726 f"\n\nconfig={config}.\n\nException encountered: {e}"
727 )
728 build_config = config.get("build_config", None)
729 if build_config and not instance.built:
TypeError: <class 'keras.src.models.functional.Functional'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.
Exception encountered: <class '__main__.Patches'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.
config={'module': None, 'class_name': 'Patches', 'config': {'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}, 'registered_name': 'Custom>Patches', 'build_config': {'input_shape': [None, 72, 72, 3]}, 'name': 'patches_1', 'inbound_nodes': [{'args': [{'class_name': '__keras_tensor__', 'config': {'shape': [None, 72, 72, 3], 'dtype': 'float32', 'keras_history': ['data_augmentation', 0, 0]}}], 'kwargs': {}}]}.
Exception encountered: Error when deserializing class 'Patches' using config={'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}.
Exception encountered: Patches.__init__() got an unexpected keyword argument 'name'
The code is copy-pasted from the website, except for the save and load commands.
Please help me solve this. Is there a specific way to save these models so they can be loaded later in a completely different notebook? (I am running this code on Kaggle.)
To summarize my comments under the question, here is a complete answer with the code below. I started from the code linked in the question; the lines I added are marked with # comments. Only the custom layer classes need to be modified.
@keras.saving.register_keras_serializable()  # <- this line
class Patches(layers.Layer):
    """Cut each image in a batch into square patches and flatten every patch.

    The class is registered with Keras so that a saved model containing it
    can be restored with ``load_model()``.
    """

    def __init__(self, patch_size, **kwargs):  # <- add **kwargs
        # Keras passes base-layer arguments (``name``, ``trainable``, ``dtype``)
        # back in when rebuilding from a saved config; hand them to ``Layer``.
        super().__init__(**kwargs)  # <- add **kwargs
        self.patch_size = patch_size

    def call(self, images):
        # Assumes ``images`` is (batch, height, width, channels) — TODO confirm
        shape = ops.shape(images)
        batch = shape[0]
        rows = shape[1] // self.patch_size
        cols = shape[2] // self.patch_size
        depth = shape[3]

        extracted = keras.ops.image.extract_patches(images, size=self.patch_size)

        # Collapse the rows x cols grid into one sequence of patch vectors,
        # each of length patch_size * patch_size * channels.
        patch_area = self.patch_size * self.patch_size
        return ops.reshape(extracted, (batch, rows * cols, patch_area * depth))

    def get_config(self):
        """Return the base layer config extended with ``patch_size``."""
        return {**super().get_config(), "patch_size": self.patch_size}
# ------------------------------------------------------------------
@keras.saving.register_keras_serializable()  # this line
class PatchEncoder(layers.Layer):
    """Project patch vectors to ``projection_dim`` and add learned position embeddings.

    The class is registered with Keras so that a saved model containing it
    can be restored with ``load_model()``.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):  # <- add **kwargs
        # Forward base-layer arguments (``name``, ``trainable``, ``dtype``)
        # supplied by Keras when rebuilding the layer from a saved config.
        super().__init__(**kwargs)  # <- add **kwargs
        self.num_patches = num_patches
        self.projection_dim = projection_dim  # save projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def build(self, input_shape):  # add build method (this threw only a warning)
        # build() must create ALL state, including that of the child layers.
        # Calling only super().build() would mark this layer as built while
        # `projection` and `position_embedding` still own no weights.
        self.projection.build(input_shape)
        # `call` looks up positions of shape (1, num_patches).
        self.position_embedding.build((1, self.num_patches))
        super().build(input_shape)

    def call(self, patch):
        # Position indices 0..num_patches-1 with a leading axis of size 1, so
        # the embeddings broadcast across the batch dimension when added.
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the base config plus the constructor arguments of this layer."""
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,  # this line
            }
        )
        return config
Short explanation of the added code lines:
@keras.saving.register_keras_serializable()
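This decorator registers the custom layer class in Keras's global registry of serializable objects (here under the name `Custom>Patches`), so that `load_model()` can map the class name stored in the saved file back to your class. When loading in a different notebook, you must still run the class definitions before calling `load_model()`, because otherwise nothing is registered under that name.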
**kwargs
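Collects any keyword arguments that `__init__()` receives but does not list explicitly, and forwards them to the `super().__init__()` call. In this case `__init__()` received the argument `name`, because the saved config of every `Layer` contains it (together with `trainable` and `dtype`, as the error message shows). The original `__init__()` did not accept `name`, so deserialization failed.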
self.projection_dim = projection_dim
# ...
config.update({"projection_dim": self.projection_dim})
These two lines store `projection_dim` in the config of the `PatchEncoder` layer, so that the same value is passed back to `__init__()` when the layer is recreated on load.