Tags: tensorflow, keras, deep-learning, tensorflow-datasets, multitasking

Cascaded Convolutional Neural Network - multi-input and multi-output with TensorFlow API


I am trying to implement the cascaded model proposed in this paper, but I am facing an issue in the data loading pipelines. The overall architecture of the model is shown below:

The data subsets are road: (img, mask_road) and centerline: (the same img, mask_centerline). To build the data loading pipelines, I created these inputs using the tf.data.Dataset API, i.e., read the data, decode it, and convert it to tensors in [0, 1]. Then, to train the model, I tried to zip my inputs, like:

zip_train = tf.data.Dataset.zip((dataset_train,center_train))
zip_valid = tf.data.Dataset.zip((dataset_val,center_val))
zip_test = tf.data.Dataset.zip((dataset_test,center_test))

In the model figure above, the input to the centerline extraction network (network 2) is the last feature map from network 1 (road detection) together with the respective subset. Therefore, I concatenated them during the model definition (please see the code in the attached gist file below). When I tried to run the code, the error indicated that my concatenation layer is incompatible, specifically in the number of channels, as follows:

history = model.fit(
      zip_train, 
      epochs=epochs, 
      steps_per_epoch=steps, 
      validation_data=zip_valid, 
      callbacks=callbacks
  )

Input 0 of layer "d1_11_conv" is incompatible with the layer:
expected axis -1 of input shape to have value 67, 
but received input with shape (None, 512, 512, 65)

Call arguments received by layer 'model_12' (type Functional):

  • inputs=('tf.Tensor(shape=(None, 512, 512, 3), dtype=float32)',
   'tf.Tensor(shape=(None, 512, 512, 1), dtype=float32)')
  • training=True
  • mask=None

How can I resolve this? Here is the reproducible code that contains the cascaded model definition and the data loading pipelines.


Solution

  • Let's first summarize the whole picture. The model you're trying to build and run is essentially two auto-encoder models designed to solve two tasks simultaneously: given an input image, the model produces two outputs, a road map and a centerline map. But first, we need to train this model on the given dataset, where each sample contains an image and the corresponding road and centerline segmentation masks. In simple terms, we can frame this problem as semantic segmentation with 1 input and 2 outputs.

    Building the training data loader with the tf.data API for such a task is quite straightforward. There are different possible approaches, but the overall setup stays the same. The error you faced while concatenating layers is expected, but according to the figure from the paper you don't need that concatenation: you can simply pass the feature maps of the first network to the next one. Let's build this project step by step. I'm using TF 2.11, tested on Kaggle with a P100 GPU.

    Model
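    All the code below assumes the usual TensorFlow/Keras imports (my assumption; adjust if your notebook names them differently):

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers, backend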

    Some common blocks of layers.

    def ConvBlock(filters, kernel, kernel_initializer, activation, name=None):
        # Conv2D -> BatchNorm -> Activation, returned as a reusable closure.
        if name is None:
            name = "ConvBlock" + str(backend.get_uid("ConvBlock"))
        
        def apply(input):
            c1 = layers.Conv2D(
                filters=filters,
                kernel_size=kernel,
                padding='same',
                kernel_initializer=kernel_initializer,
                name=name+'_conv'
            )(input)
            c1 = layers.BatchNormalization(name=name+'_batch')(c1)
            c1 = layers.Activation(activation,name=name+'_active')(c1)
            return c1
        
        return apply
    
    def DownConvBlock(filters, kernel, kernel_initializer, activation, name=None):
        # Conv2DTranspose -> BatchNorm -> Activation. With stride 1 and 'same'
        # padding the spatial size is unchanged; the actual upsampling is done
        # by UpSampling2D in the decoder loops below.
        if name is None:
            name = "DownConvBlock" + str(backend.get_uid("DownConvBlock"))
        
        def apply(input):
            d1 = layers.Conv2DTranspose(
                filters=filters,
                kernel_size=kernel,
                padding='same',
                kernel_initializer=kernel_initializer,
                name=name+'_conv'
            )(input)
            d1 = layers.BatchNormalization(name=name+'_batch')(d1)
            d1 = layers.Activation(activation,name=name+'_active')(d1)
            return d1
        
        return apply
    

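    A quick sanity check of this closure pattern (my sketch; shapes assume a 512×512 RGB input):

    x = keras.Input(shape=(512, 512, 3))
    y = ConvBlock(64, (3, 3), 'he_normal', 'relu')(x)
    print(y.shape)  # (None, 512, 512, 64)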
    Sub-model for the road mask detection task.

    def network_mask(input, activation, kernel_initializer, kernel_size):
        # Network 1
        # ENCODER
        x = input
        for fmap in [64, 128, 256, 512]:
            x = ConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
            x = ConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
            x = layers.MaxPool2D(pool_size=(2,2), strides=None, padding='same')(x)
    
        # DECODER   
        for fmap in [512, 256, 128, 64]:
            x = layers.UpSampling2D(size=(2,2), interpolation='bilinear')(x)
            x = DownConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
            x = DownConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
    
        x = layers.Conv2D(
                filters=1, 
                kernel_size=(1,1),
                kernel_initializer=kernel_initializer,
                activation=None,
        )(x)
        
        return x
    
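    As a minimal shape check for this sub-network (a sketch, assuming img_size = 512): four 2×2 poolings reduce 512 to 32 and four bilinear upsamplings restore it, so the output logits match the input resolution:

    inp = keras.Input(shape=(512, 512, 3))
    out = network_mask(inp, 'relu', 'he_normal', (3, 3))
    print(out.shape)  # (None, 512, 512, 1)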

    Sub-model for the centerline mask detection task.

    def network_centerline(input, activation, kernel_initializer, kernel_size):
        # Network 2
        # ENCODER
        x = input
        for fmap in [64, 128, 256]:
            x = ConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
            x = ConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
            x = layers.MaxPool2D(pool_size=(2,2), strides=None, padding='same')(x)
    
        # DECODER   
        for fmap in [256, 128, 64]:
            x = layers.UpSampling2D(size=(2,2), interpolation='bilinear')(x)
            x = DownConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
            x = DownConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
            
        x = layers.Conv2DTranspose(
            filters=1, 
            kernel_size=(1,1), 
            kernel_initializer=kernel_initializer,
            activation=None, 
        )(x)
        
        return x
    

    Full cascaded network, i.e. CasNet.

    def CasNet(activation, kernel_initializer, kernel_size):
        input = keras.Input(shape=(img_size, img_size, channel), name='images')
        
        mask_feat = network_mask(input, activation, kernel_initializer, kernel_size)
        centerline_feat = network_centerline(
            mask_feat, activation, kernel_initializer, kernel_size
        )
        
        mask_op = keras.layers.Activation(
            'sigmoid', name='mask', dtype=tf.float32
        )(mask_feat)
        centerline_op = keras.layers.Activation(
            'sigmoid', name='centerline', dtype=tf.float32
        )(centerline_feat)
        
        model = keras.Model(
            inputs={
                'images': input
            },
            outputs={
                'mask': mask_op,
                'centerline': centerline_op
            },
            name='CasNet'
        )
        return model
    
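    Instantiating the model; img_size and channel are globals, and the values below are my assumptions based on the shapes in the question:

    img_size, channel = 512, 3
    model = CasNet(
        activation='relu',
        kernel_initializer='he_normal',
        kernel_size=(3, 3),
    )
    model.summary()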

    Data Loader

    Augmentation pipelines in Keras. Going forward, we can use keras-cv for this.

    set_seed = 101
    rand_flip = layers.RandomFlip("horizontal_and_vertical", seed=set_seed)
    rand_rote = layers.RandomRotation(factor=0.01, seed=set_seed)
    # more: https://keras.io/api/layers/preprocessing_layers/image_augmentation/
    
    def keras_augment(image, label, centerline):
        # Concatenate along channels so the image and both masks receive
        # exactly the same random flip/rotation.
        tensors = tf.concat([image, label, centerline], axis=-1)
        
        def apply_augment(x):
            x = rand_flip(x)
            x = rand_rote(x)
            return x
        
        aug_tensors = apply_augment(tensors)
        image, label, centerline = tf.split(aug_tensors, [3, 1, 1], axis=-1)
        return image, label, centerline
    
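    A quick check (a sketch with dummy tensors) that the concat/split trick keeps shapes intact while applying the same transform to the image and both masks:

    dummy_img = tf.random.uniform([4, 512, 512, 3])
    dummy_msk = tf.zeros([4, 512, 512, 1])
    dummy_ctr = tf.zeros([4, 512, 512, 1])
    img, msk, ctr = keras_augment(dummy_img, dummy_msk, dummy_ctr)
    print(img.shape, msk.shape, ctr.shape)
    # (4, 512, 512, 3) (4, 512, 512, 1) (4, 512, 512, 1)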

    Load the samples (road, mask, centerline). The pixel values of a road image are normal RGB colors in the range 0-255. The road mask and road centerline pixel values also range from 0 to 255 (stored with 3 color channels in the source files, but we decode them to a single channel and binarize them). We will normalize these values.

    def read_files(image_path, mask=False):
        image = tf.io.read_file(image_path)
        if mask:
            image = tf.io.decode_png(image, channels=1, dtype=tf.uint8)
            image = tf.image.resize(
                images=image, 
                size=[img_size, img_size], 
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
            )
            image = tf.where(image == 255, 1, 0)
            image = tf.cast(image, tf.float32)
        else:
            image = tf.io.decode_png(image, channels=3, dtype=tf.uint8)
            image = tf.image.resize(images=image, size=[img_size, img_size])
            image = tf.cast(image, tf.float32)
            image = image / 255.
        return image
    
    def load_data(image_list, label_list, centerline_list): 
        image = read_files(image_list)
        label = read_files(label_list, mask=True)
        center = read_files(centerline_list, mask=True)
        return image, label, center
    
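    The file lists themselves can be gathered with glob; the directory layout below is hypothetical, so adapt the paths to the actual dataset:

    from glob import glob

    # hypothetical layout; adjust to where the dataset actually lives
    train_images_path = sorted(glob('train/images/*.png'))
    train_masks_path = sorted(glob('train/masks/*.png'))
    train_center_path = sorted(glob('train/centerlines/*.png'))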

    Notice how we pack the data (in the prepare_dict method below) for a single input and multiple outputs. The same thing could be done for multi-input/multi-output or multi-input/single-output, etc. Again, as mentioned, there are different ways to load such a dataset with the same API, but the overall setup stays the same; I won't list the alternatives to avoid confusion.

    def prepare_dict(image_batch, label_batch, centerline_batch):
        return {'images': image_batch}, {'mask':label_batch, 'centerline':centerline_batch}
    
    def dataloader(image_list, label_list, center_list, split='train'):
        dataset = tf.data.Dataset.from_tensor_slices(
            (image_list, label_list, center_list)
        )
        dataset = dataset.shuffle(batch_size * 8) if split == 'train' else dataset
        dataset = dataset.repeat() if split == 'train' else dataset
        dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.map(keras_augment) if split == 'train' else dataset
        dataset = dataset.map(prepare_dict, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
        return dataset
    
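    Building the actual datasets (batch_size is a global; 16 is my assumption, and the validation lists are gathered the same way as the training ones):

    batch_size = 16
    train_ds = dataloader(train_images_path, train_masks_path, train_center_path, split='train')
    valid_ds = dataloader(valid_images_path, valid_masks_path, valid_center_path, split='valid')
    print(train_ds.element_spec)
    # ({'images': TensorSpec(...)}, {'mask': TensorSpec(...), 'centerline': TensorSpec(...)})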


    Compile and Run

    Let's compile the model with losses and metrics and fit it. For the loss and metric functions, we will use this library until keras-cv is ready for segmentation tasks. Note in the loss and metrics arguments below that we pass a separate loss and metric function for each output of the model. We could simply pass a single loss/metric method and it would be applied to both outputs, but it's good to know we can pass them in this per-output format.
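    The sm namespace comes from the segmentation_models package; the setup I assume is:

    # pip install segmentation-models
    import segmentation_models as sm
    sm.set_framework('tf.keras')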

    model.compile(
        optimizer=keras.optimizers.Adam(
            learning_rate=0.0001
        ),
        loss={
            'mask':sm.losses.bce_jaccard_loss,
            'centerline': sm.losses.binary_focal_jaccard_loss
        },
        metrics={
            'mask': sm.metrics.iou_score,
            'centerline': sm.metrics.f1_score
        }
    )
    
    history = model.fit(
        train_ds, 
        validation_data=valid_ds,
        steps_per_epoch=len(train_images_path) // batch_size,
        callbacks=my_callbacks,
        epochs=epoch
    )
    
    ...
    ...
    160/160 [==============================] - 186s
    loss: 1.0082 - centerline_loss: 0.7613 - mask_loss: 0.2469 -
    centerline_f1-score: 0.4074 - mask_iou_score: 0.8115 - 
    val_loss: 1.2867 - val_centerline_loss: 0.7986 - 
    val_mask_loss: 0.4882 - val_centerline_f1-score: 0.3572 - 
    val_mask_iou_score: 0.6860
    
    160/160 [==============================] - 186s 1s/step - 
    loss: 0.9827 - centerline_loss: 0.7491 - mask_loss: 0.2336 - 
    centerline_f1-score: 0.4223 - mask_iou_score: 0.8210 - 
    val_loss: 1.4251 - val_centerline_loss: 0.8222 - 
    val_mask_loss: 0.6028 - val_centerline_f1-score: 0.3160 - 
    val_mask_iou_score: 0.6344
    ...
    ...
    
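    After training, a quick way to inspect predictions (a sketch; since the model outputs are a dict, predict returns a dict keyed by output names):

    batch_inputs, batch_targets = next(iter(valid_ds))
    preds = model.predict(batch_inputs)
    print(preds['mask'].shape, preds['centerline'].shape)
    # (16, 512, 512, 1) (16, 512, 512, 1)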


    Full Code and Resources

    Here is the full code; it runs on Kaggle (P100, TF 2.11). Here are some resources that might come in handy; most of them relate to segmentation modeling and loss function selection.