azure tensorflow model

Load model from model.weights.h5 file stored in Azure Blob


I am trying to load my model from a file I saved in Azure Blob Storage.

This is the function that downloads the file to my local machine:

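import os
import tempfile

from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient

# account_url, CONTAINER_NAME_MODEL and MODEL_BLOB_NAME are assumed to be defined elsewhere.

def download_weights_to_temp_file():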
    """
    Downloads the model weights from Azure Blob Storage to a temporary file.
    Returns the local path to the temporary file.
    """
    try:
        # Authenticate with Azure
        default_credential = DefaultAzureCredential()
        blob_service_client = BlobServiceClient(account_url, credential=default_credential)
        blob_client = blob_service_client.get_blob_client(container=CONTAINER_NAME_MODEL, blob=MODEL_BLOB_NAME)
        
        print(f"Downloading model weights from: {blob_client.url}")

        # Create a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".h5") as temp_file:
            temp_file.write(blob_client.download_blob().readall())
            temp_model_path = temp_file.name

        print(f"Model weights downloaded to: {temp_model_path}")

        if os.path.exists(temp_model_path):
            print(f"Downloaded file exists at: {temp_model_path}")
        else:
            print("Error: Downloaded file does not exist.")

        return temp_model_path
    except Exception as e:
        print(f"Error downloading model weights: {e}")
        raise

And this is the function that loads my model:

def load_model():
    temp_model_path = download_weights_to_temp_file()
    try:
        # Create the model architecture
        model = create_enet_model(input_shape=(256, 256, 1), num_classes=1)
        model.compile(optimizer="adam", loss=weighted_binary_crossentropy, metrics=["accuracy"])
        model.load_weights(temp_model_path)
        print("Successfully loaded model weights from: enet_light_fold_4.weights.h5")

    except Exception as e:
        print(f"Error loading model weights: {e}")
        raise
    finally:
        # Clean up the temporary file
        if os.path.exists(temp_model_path):
            os.remove(temp_model_path)

    return model

However, I get the following error:

  File "C:\Users\Amanda\anaconda3\envs\tf17\lib\site-packages\keras\src\utils\traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\Amanda\anaconda3\envs\tf17\lib\site-packages\keras\src\legacy\saving\legacy_h5_format.py", line 357, in load_weights_from_hdf5_group
    raise ValueError(
ValueError: Layer count mismatch when loading weights from file. Model expected 13 layers, found 0 saved layers.

But if I run my code with the original file I have, it works fine. I imagine something goes wrong when reading the file from Azure, but nothing GPT suggested worked here.

I expect to be able to load my model.


Solution

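  • The temporary file is created with the suffix .h5, but the weights were saved as .weights.h5. Keras picks its loader from the file name: a path ending in .weights.h5 is read with the current weights format, while any other .h5 path is handed to the legacy HDF5 loader (the legacy_h5_format.py in your traceback), which finds none of the layers it expects. That is the reason for the error; the download itself is most likely fine.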

    Change the suffix when creating the temporary file:

    with tempfile.NamedTemporaryFile(delete=False, suffix=".weights.h5") as temp_file:
        temp_file.write(blob_client.download_blob().readall())
        temp_model_path = temp_file.name
    

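If you would rather not hard-code the suffix, one option is to derive it from the blob name so the temporary file always keeps whatever ending the original had. This is only a minimal sketch: suffix_for_blob and write_temp_copy are hypothetical helper names, not part of Keras or the Azure SDK, and the list of suffixes is an assumption about which file types you store.

```python
import os
import tempfile

# Checked in order, so ".weights.h5" is matched before the plain ".h5".
# (Assumed list of endings; adjust to the files you actually store.)
KNOWN_SUFFIXES = (".weights.h5", ".keras", ".hdf5", ".h5")


def suffix_for_blob(blob_name: str) -> str:
    """Return the suffix the temporary file should keep (hypothetical helper)."""
    for suffix in KNOWN_SUFFIXES:
        if blob_name.endswith(suffix):
            return suffix
    # Fall back to the plain extension for anything unrecognised.
    return os.path.splitext(blob_name)[1]


def write_temp_copy(data: bytes, blob_name: str) -> str:
    """Write downloaded bytes to a temp file whose name keeps the blob's suffix."""
    suffix = suffix_for_blob(blob_name)
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(data)
        return temp_file.name
```

In the question's download function this would be called as write_temp_copy(blob_client.download_blob().readall(), MODEL_BLOB_NAME), and the returned path passed to model.load_weights as before.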