python, tensorflow, keras, deep-learning, unet-neural-network

Issues trying to load saved Keras U-Net model from h5 file


I've been assigned a task at my company to rehydrate a model that was trained for a previous project. While I can load it, I'm failing when I try to run inference with it, and I don't know why.

The model follows a U-Net architecture, and here's the output of the summary() method after calling load_weights().

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 640, 512, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 640, 512, 64  640         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_1 (Conv2D)              (None, 640, 512, 64  36928       ['conv2d[0][0]']                 
                                )                                                                 
                                                                                                  
 max_pooling2d (MaxPooling2D)   (None, 320, 256, 64  0           ['conv2d_1[0][0]']               
                                )                                                                 
                                                                                                  
 conv2d_2 (Conv2D)              (None, 320, 256, 12  73856       ['max_pooling2d[0][0]']          
                                8)                                                                
                                                                                                  
 conv2d_3 (Conv2D)              (None, 320, 256, 12  147584      ['conv2d_2[0][0]']               
                                8)                                                                
                                                                                                  
 max_pooling2d_1 (MaxPooling2D)  (None, 160, 128, 12  0          ['conv2d_3[0][0]']               
                                8)                                                                
                                                                                                  
 conv2d_4 (Conv2D)              (None, 160, 128, 25  295168      ['max_pooling2d_1[0][0]']        
                                6)                                                                
                                                                                                  
 conv2d_5 (Conv2D)              (None, 160, 128, 25  590080      ['conv2d_4[0][0]']               
                                6)                                                                
                                                                                                  
 max_pooling2d_2 (MaxPooling2D)  (None, 80, 64, 256)  0          ['conv2d_5[0][0]']               
                                                                                                  
 conv2d_6 (Conv2D)              (None, 80, 64, 512)  1180160     ['max_pooling2d_2[0][0]']        
                                                                                                  
 conv2d_7 (Conv2D)              (None, 80, 64, 512)  2359808     ['conv2d_6[0][0]']               
                                                                                                  
 max_pooling2d_3 (MaxPooling2D)  (None, 40, 32, 512)  0          ['conv2d_7[0][0]']               
                                                                                                  
 conv2d_8 (Conv2D)              (None, 40, 32, 1024  4719616     ['max_pooling2d_3[0][0]']        
                                )                                                                 
                                                                                                  
 conv2d_9 (Conv2D)              (None, 40, 32, 1024  9438208     ['conv2d_8[0][0]']               
                                )                                                                 
                                                                                                  
 up_sampling2d (UpSampling2D)   (None, 80, 64, 1024  0           ['conv2d_9[0][0]']               
                                )                                                                 
                                                                                                  
 concatenate (Concatenate)      (None, 80, 64, 1536  0           ['up_sampling2d[0][0]',          
                                )                                 'conv2d_7[0][0]']               
                                                                                                  
 conv2d_10 (Conv2D)             (None, 80, 64, 512)  7078400     ['concatenate[0][0]']            
                                                                                                  
 conv2d_11 (Conv2D)             (None, 80, 64, 512)  2359808     ['conv2d_10[0][0]']              
                                                                                                  
 up_sampling2d_1 (UpSampling2D)  (None, 160, 128, 51  0          ['conv2d_11[0][0]']              
                                2)                                                                
                                                                                                  
 concatenate_1 (Concatenate)    (None, 160, 128, 76  0           ['up_sampling2d_1[0][0]',        
                                8)                                'conv2d_5[0][0]']               
                                                                                                  
 conv2d_12 (Conv2D)             (None, 160, 128, 25  1769728     ['concatenate_1[0][0]']          
                                6)                                                                
                                                                                                  
 conv2d_13 (Conv2D)             (None, 160, 128, 25  590080      ['conv2d_12[0][0]']              
                                6)                                                                
                                                                                                  
 up_sampling2d_2 (UpSampling2D)  (None, 320, 256, 25  0          ['conv2d_13[0][0]']              
                                6)                                                                
                                                                                                  
 concatenate_2 (Concatenate)    (None, 320, 256, 38  0           ['up_sampling2d_2[0][0]',        
                                4)                                'conv2d_3[0][0]']               
                                                                                                  
 conv2d_14 (Conv2D)             (None, 320, 256, 12  442496      ['concatenate_2[0][0]']          
                                8)                                                                
                                                                                                  
 conv2d_15 (Conv2D)             (None, 320, 256, 12  147584      ['conv2d_14[0][0]']              
                                8)                                                                
                                                                                                  
 up_sampling2d_3 (UpSampling2D)  (None, 640, 512, 12  0          ['conv2d_15[0][0]']              
                                8)                                                                
                                                                                                  
 concatenate_3 (Concatenate)    (None, 640, 512, 19  0           ['up_sampling2d_3[0][0]',        
                                2)                                'conv2d_1[0][0]']               
                                                                                                  
 conv2d_16 (Conv2D)             (None, 640, 512, 64  110656      ['concatenate_3[0][0]']          
                                )                                                                 
                                                                                                  
 conv2d_17 (Conv2D)             (None, 640, 512, 64  36928       ['conv2d_16[0][0]']              
                                )                                                                 
                                                                                                  
 conv2d_18 (Conv2D)             (None, 640, 512, 1)  65          ['conv2d_17[0][0]']              
                                                                                                  
==================================================================================================
Total params: 31,377,793
Trainable params: 31,377,793
Non-trainable params: 0
__________________________________________________________________________________________________

My main concern is that when I load a picture as a numpy array, ending up with an input shaped (640, 512, 1), matching the input layer, I get the following error.

from tensorflow.keras.preprocessing.image import load_img, img_to_array

img_size = (640, 512)      # (height, width) expected by the model
color_mode = "grayscale"
image = img_to_array(load_img(image_path, target_size=img_size, color_mode=color_mode))
image = image / 255.0      # scale pixel values to [0, 1]
print(image.shape)
#(640, 512, 1)

#unet is a wrapper class which contains the model described above
#unet.load_weights('../../models/unet_model_i_04.h5')
unet.model.predict(image)

This snippet produces this error:

ValueError: Input 0 of layer "model" is incompatible with the layer: expected shape=(None, 640, 512, 1), found shape=(32, 512, 1)

I tried changing the shape of the input, and hence the image size (I suspect, based on an outdated notebook I was granted access to, that the model was originally trained and exported with 320x256 images), but that only changes the error to

ValueError: Input 0 of layer "model" is incompatible with the layer: expected shape=(None, 320, 256, 1), found shape=(32, 256, 1)
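
For reference, the input resolution the saved model expects can also be read programmatically from the loaded model rather than inferred from the outdated notebook; a minimal check, assuming unet.model is the underlying Keras Model:

print(unet.model.input_shape)
# e.g. (None, 640, 512, 1) -> (batch, height, width, channels)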

Solution

  • As someone suggested in the comments, expanding the dimensions of the input tensor fixes this. The model expects a batch of images as input: without the batch axis, Keras treats the first image axis as the batch dimension and splits it into predict's default batches of 32, which is where the (32, 512, 1) in the error comes from. For a single image, the input shape should therefore be (1, H, W, 1):

    import tensorflow as tf
    from tensorflow.keras.preprocessing.image import load_img, img_to_array
    
    img_size = (640, 512)      # (height, width) expected by the model
    color_mode = "grayscale"
    image = img_to_array(load_img(image_path, target_size=img_size, color_mode=color_mode))
    image = image / 255.0
    image = tf.expand_dims(image, 0)   # add the batch dimension: (640, 512, 1) -> (1, 640, 512, 1)
    print(image.shape)
    #(1, 640, 512, 1)
    
    #unet is a wrapper class which contains the model described above
    #unet.load_weights('../../models/unet_model_i_04.h5')
    unet.model.predict(image)
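
    If you prefer to stay in NumPy, the tf.expand_dims call can be swapped for np.expand_dims (or indexing with None); predict then returns a batch as well, so the single mask is the first element of the output. A minimal sketch, assuming the same image_path, img_size and color_mode as above:
    
    import numpy as np
    from tensorflow.keras.preprocessing.image import load_img, img_to_array
    
    image = img_to_array(load_img(image_path, target_size=img_size, color_mode=color_mode))
    image = image / 255.0
    
    # equivalent to tf.expand_dims(image, 0): add a leading batch axis with NumPy
    batch = np.expand_dims(image, axis=0)   # or image[None, ...]
    print(batch.shape)
    #(1, 640, 512, 1)
    
    # predict returns a batch as well, so the single mask is the first element
    pred = unet.model.predict(batch)
    mask = pred[0]   # shape (640, 512, 1), matching the last layer of the summary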