I am trying to load my model from file I saved in Azure blob storage.
This is the function that downloads the file to my local machine:
def download_weights_to_temp_file():
    """
    Download the model weights from Azure Blob Storage to a temporary file.

    Returns:
        str: Local path to the temporary weights file. The caller is
        responsible for deleting the file when done (delete=False below).

    Raises:
        Exception: Re-raises any Azure/auth/IO error after logging it.
    """
    try:
        # Authenticate with Azure via the ambient credential chain
        # (environment variables, managed identity, CLI login, ...).
        default_credential = DefaultAzureCredential()
        blob_service_client = BlobServiceClient(account_url, credential=default_credential)
        blob_client = blob_service_client.get_blob_client(
            container=CONTAINER_NAME_MODEL, blob=MODEL_BLOB_NAME
        )
        print(f"Downloading model weights from: {blob_client.url}")

        # Keras requires a weights file name ending in ".weights.h5".
        # A plain ".h5" suffix sends load_weights() down the legacy
        # whole-model HDF5 path, which fails with
        # "Layer count mismatch ... found 0 saved layers".
        with tempfile.NamedTemporaryFile(delete=False, suffix=".weights.h5") as temp_file:
            temp_file.write(blob_client.download_blob().readall())
            temp_model_path = temp_file.name

        print(f"Model weights downloaded to: {temp_model_path}")
        if os.path.exists(temp_model_path):
            print(f"Downloaded file exists at: {temp_model_path}")
        else:
            print("Error: Downloaded file does not exist.")
        return temp_model_path
    except Exception as e:
        print(f"Error downloading model weights: {e}")
        raise
And this is the function that loads my model:
def load_model():
    """
    Build the ENet architecture, load its weights downloaded from Azure
    Blob Storage, and return the compiled model.

    Returns:
        The compiled Keras model with weights loaded.

    Raises:
        Exception: Re-raises any download or weight-loading error after
        logging it; the temporary weights file is removed either way.
    """
    temp_model_path = download_weights_to_temp_file()
    try:
        # Recreate the architecture first: load_weights() requires the
        # in-memory model to match the saved weights layer-for-layer.
        model = create_enet_model(input_shape=(256, 256, 1), num_classes=1)
        model.compile(
            optimizer="adam",
            loss=weighted_binary_crossentropy,
            metrics=["accuracy"],
        )
        model.load_weights(temp_model_path)
        # Log the file actually loaded rather than a hard-coded name,
        # which would be misleading if the blob/temp path changes.
        print(f"Successfully loaded model weights from: {temp_model_path}")
    except Exception as e:
        print(f"Error loading model weights: {e}")
        raise
    finally:
        # Clean up the temporary file even when loading fails.
        if os.path.exists(temp_model_path):
            os.remove(temp_model_path)
    return model
However, I get following error:
File "C:\Users\Amanda\anaconda3\envs\tf17\lib\site-packages\keras\src\utils\traceback_utils.py", line 122, in error_handler raise e.with_traceback(filtered_tb) from None File "C:\Users\Amanda\anaconda3\envs\tf17\lib\site-packages\keras\src\legacy\saving\legacy_h5_format.py", line 357, in load_weights_from_hdf5_group raise ValueError( ValueError: Layer count mismatch when loading weights from file. Model expected 13 layers, found 0 saved layers.
But if I run my code with the original file I have locally, it works fine. I imagine the problem is something about how the file is read from Azure, but nothing GPT suggested worked here.
I expect to be able to load my model.
The model weights you are downloading are being saved to a temporary file whose name ends with .h5,
but Keras expects a weights file to end with .weights.h5
— this is the reason you are getting the error.
Change the suffix when creating the temporary file:
# Keras only treats a file as a weights file when its name ends in
# ".weights.h5"; a bare ".h5" suffix triggers the legacy whole-model
# HDF5 loader and the "found 0 saved layers" error.
with tempfile.NamedTemporaryFile(delete=False, suffix=".weights.h5") as temp_file:
    temp_file.write(blob_client.download_blob().readall())
    temp_model_path = temp_file.name
Output: