pythontensorflowkerassemantic-segmentationdeeplab

In need of a Labelencoder code to fix the InvalidArgumentError: Graph execution error when Using deeplabv3+ for a 4 class semantic segmentation


Am running a 4 class semantic segmentation problem using Deeplabv3+ and I get the graph execution error as soon as the training starts. I have identified the problem after searching the web for solutions. The problem is with the labels. My labels are 2, 4, 6, 8 instead of 0,1,2,3. Currently the model trains after adjusting to num_classes: 9 so as to cater for all the labels. sounds weird since I should encode the labels to 0,1,2,3. I have failed to encode the labels to to fit in my main code so as to get rid of the invalid labels. Can some help me with the script and where to fix it in my code? Thank you :)

This is the code that i generated and i will be glad if anyone goes through it

import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split
import albumentations as A
from glob import glob
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
path = "D:/K/deeplabv3plus/datasets"
os.chdir(path)

os.listdir(os.path.join('128_patches', 'masks'))[:924]
os.listdir(os.path.join('128_patches', 'images'))[:924]

config = {
    'IMG_PATH': os.path.join('128_patches', 'images'),
    'LABEL_PATH': os.path.join('128_patches', 'masks'),
    'NUM_CLASSES': 4,
    'BATCH_SIZE': 2,
    'IMAGE_SIZE': 128
}

##### Building Dataset
image_paths =  glob(os.path.join(config['IMG_PATH'], '*'), recursive=True)
mask_paths =  glob(os.path.join(config['LABEL_PATH'], '*'), recursive=True)

#image_paths_train, image_paths_test, mask_paths_train, mask_paths_test = train_test_split(image_paths, mask_paths, shuffle=True)
image_paths_train1, image_paths_test1, mask_paths_train1, mask_paths_test1 = train_test_split(image_paths, mask_paths, test_size=0.15)
image_paths_train, image_paths_test, mask_paths_train, mask_paths_test = train_test_split(image_paths_train1, mask_paths_train1, test_size=0.15)

config['DATASET_LENGTH'] = len(image_paths_train)
def preprocess(image_path, mask_path):
img = tf.io.read_file(image_path)
img = tf.image.decode_jpeg(img, channels=3)
img = tf.image.resize(img, size=[config['IMAGE_SIZE'], config['IMAGE_SIZE']])
img = tf.cast(img, tf.float32) / 255.0
    
mask = tf.io.read_file(mask_path)
# Only one channel for masks, denoting the class and NOT image colors
mask = tf.image.decode_png(mask, channels=1)
mask = tf.image.resize(mask, size=[config['IMAGE_SIZE'], config['IMAGE_SIZE']])
mask = tf.cast(mask, tf.float32)
return img, mask

def augment_dataset_tf(img, mask):
     #  Augmentations should always be performed on both an input image and a mask if applied at all
    if tf.random.uniform(()) > 0.5:
        img = tf.image.flip_left_right(img)
        mask = tf.image.flip_left_right(mask)
    if tf.random.uniform(()) > 0.5:
        img = tf.image.flip_up_down(img)
        mask = tf.image.flip_up_down(mask)
    if tf.random.uniform(()) > 0.5:
        img = tf.image.rot90(img)
        mask = tf.image.rot90(mask)
            
    return img, mask

def albumentations(img, mask):
    # Augmentation pipeline - each of these has an adjustable probability
    # of being applied, regardless of other transforms
    transform = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.3),
        A.Transpose(p=0.5),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=70),
        # CoarseDropout is the new Cutout implementation
        A.CoarseDropout(p=0.5, max_holes=12, max_height=24, max_width=24)
    ])
    
    # Apply transforms and extract image and mask
    transformed = transform(image=img, mask=mask)
    transformed_image = transformed['image']
    transformed_mask = transformed['mask']
    
    # Cast to TF Floats and return
    transformed_image = tf.cast(transformed_image, tf.float32)
    transformed_mask = tf.cast(transformed_mask, tf.float32)
    return transformed_image, transformed_mask

def create_dataset_tf(images, masks, augment):
    dataset = tf.data.Dataset.from_tensor_slices((images, masks)).shuffle(len(images))
    dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    if augment:
        dataset = dataset.map(apply_albumentations, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(config['BATCH_SIZE'], drop_remainder=True).prefetch(tf.data.AUTOTUNE).repeat()
    else:
        dataset = dataset.batch(config['BATCH_SIZE'], drop_remainder=True).prefetch(tf.data.AUTOTUNE)
    return dataset

def apply_albumentations(img, mask):
    aug_img, aug_mask = tf.numpy_function(func=albumentations, inp=[img, mask], Tout=[tf.float32, tf.float32])
    aug_img = tf.ensure_shape(aug_img, shape=[config['IMAGE_SIZE'], config['IMAGE_SIZE'], 3])
    aug_mask = tf.ensure_shape(aug_mask, shape=[config['IMAGE_SIZE'], config['IMAGE_SIZE'], 1])
    return aug_img, aug_mask

train_set = create_dataset_tf(image_paths_train, mask_paths_train, augment=False)
test_set = create_dataset_tf(image_paths_test, mask_paths_test, augment=False)

for img_batch, mask_batch in train_set.take(2):
    for i in range(len(img_batch)):
        fig, ax = plt.subplots(1, 2)
        ax[0].imshow(img_batch[i].numpy())
        ax[1].imshow(mask_batch[i].numpy())
        
# Turns into atrous_block with dilation_rate > 1
def conv_block(block_input, num_filters=256, kernel_size=(3, 3), dilation_rate=1, padding="same"):
    x = keras.layers.Conv2D(num_filters, kernel_size=kernel_size, dilation_rate=dilation_rate, padding="same")(block_input)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Activation('relu')(x)
    return x

# Atrous Spatial Pyramid Pooling
def ASPP(inputs):
    # 4 conv blocks with dilation rates at `[1, 6, 12, 18]`
    conv_1 = conv_block(inputs, kernel_size=(1, 1), dilation_rate=1)
    conv_6 = conv_block(inputs, kernel_size=(3, 3), dilation_rate=6)
    conv_12 = conv_block(inputs, kernel_size=(3, 3), dilation_rate=12)
    conv_18 = conv_block(inputs, kernel_size=(3, 3), dilation_rate=18)
    
    dims = inputs.shape
    # Image Pooling -> (256, 256, 3) -> (1, 1, filter_num) -> (32, 32, 256)
    x = keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(inputs)
    x = conv_block(x, kernel_size=1)
    out_pool = keras.layers.UpSampling2D(size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]))(x)

    x = keras.layers.Concatenate()([conv_1, conv_6, conv_12, conv_18, out_pool])
    return conv_block(x, kernel_size=1)

def define_deeplabv3_plus(image_size, num_classes, backbone):
    model_input = keras.Input(shape=(image_size, image_size, 3))
    
    if backbone == 'resnet':
        resnet101 = keras.applications.ResNet152(
            weights="imagenet", 
            include_top=False, 
            input_tensor=model_input)
        x = resnet101.get_layer("conv4_block6_2_relu").output
        low_level = resnet101.get_layer("conv2_block3_2_relu").output
        
    elif backbone == 'effnet':
        effnet = keras.applications.EfficientNetV2B1(
            weights="imagenet", 
             include_top=False, 
             input_tensor=model_input)
        x = effnet.get_layer("block5e_activation").output
        low_level = effnet.get_layer("block2a_expand_activation").output
        
    aspp_result = ASPP(x)
    upsampled_aspp = keras.layers.UpSampling2D(size=(4, 4))(aspp_result)
    
    low_level = conv_block(low_level, num_filters=48, kernel_size=1)

    x = keras.layers.Concatenate()([upsampled_aspp, low_level])
    x = conv_block(x)
    x = keras.layers.UpSampling2D(size=(4, 4))(x)
    model_output = keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same", activation='softmax')(x)
    return keras.Model(inputs=model_input, outputs=model_output)



#model = define_deeplabv3_plus(config['IMAGE_SIZE'], config['NUM_CLASSES'], 'resnet')
model = define_deeplabv3_plus(config['IMAGE_SIZE'], config['NUM_CLASSES'], 'effnet')
model.summary()

from keras import backend as K

def dice_coef(y_true, y_pred, smooth=1e-7):
    y_true_f = K.flatten(K.one_hot(K.cast(y_true, 'int32'), num_classes=config['NUM_CLASSES'])[...,1:])
    y_pred_f = K.flatten(y_pred[...,1:])
    intersect = K.sum(y_true_f * y_pred_f, axis=-1)
    denom = K.sum(y_true_f + y_pred_f, axis=-1)
    return K.mean((2. * intersect / (denom + smooth)))

def dice_loss(y_true, y_pred):
    return 1 - dice_coef(y_true, y_pred)

class MeanIoU(tf.keras.metrics.MeanIoU):
    def __init__(self,
               y_true=None,
               y_pred=None,
               num_classes=None,
               name=None,
               dtype=None):
        super(MeanIoU, self).__init__(num_classes = num_classes,name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.math.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)
    
reduceLr = keras.callbacks.ReduceLROnPlateau(patience=5, factor=0.3, monitor='val_sparse_categorical_accuracy')
early_stopping = keras.callbacks.EarlyStopping(patience=10, monitor='val_sparse_categorical_accuracy', restore_best_weights=True)

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    #loss=soft_dice_loss,
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy", 
              MeanIoU(num_classes=config['NUM_CLASSES']),
              dice_coef])

history = model.fit(train_set, 
                    epochs=100, 
                    steps_per_epoch=int(config['DATASET_LENGTH']/config['BATCH_SIZE']), 
                    validation_data=test_set,
                    callbacks=[reduceLr, early_stopping])

Now, below is the error that comes up the moment i begin to train the model

Epoch 1/100

2022-10-17 09:57:16.605349: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8500
Traceback (most recent call last):

  File "C:\Users\Windows\AppData\Local\Temp\ipykernel_10008\204393382.py", line 12, in <module>
    history = model.fit(train_set,

  File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\utils\traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None

  File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\tensorflow\python\eager\execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,

InvalidArgumentError: Graph execution error:

Detected at node 'confusion_matrix/assert_less/Assert/AssertGuard/Assert' defined at (most recent call last):
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\runpy.py", line 86, in _run_code
      exec(code, run_globals)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\spyder_kernels\console\__main__.py", line 24, in <module>
      start.main()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\spyder_kernels\console\start.py", line 332, in main
      kernel.start()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\kernelapp.py", line 677, in start
      self.io_loop.start()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\tornado\platform\asyncio.py", line 199, in start
      self.asyncio_loop.run_forever()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\asyncio\base_events.py", line 600, in run_forever
      self._run_once()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\asyncio\base_events.py", line 1896, in _run_once
      handle._run()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\asyncio\events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\kernelbase.py", line 471, in dispatch_queue
      await self.process_one()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\kernelbase.py", line 460, in process_one
      await dispatch(*args)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\kernelbase.py", line 367, in dispatch_shell
      await result
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\kernelbase.py", line 662, in execute_request
      reply_content = await reply_content
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\ipkernel.py", line 360, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\zmqshell.py", line 532, in run_cell
      return super().run_cell(*args, **kwargs)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\IPython\core\interactiveshell.py", line 2914, in run_cell
      result = self._run_cell(
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\IPython\core\interactiveshell.py", line 2960, in _run_cell
      return runner(coro)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\IPython\core\async_helpers.py", line 78, in _pseudo_sync_runner
      coro.send(None)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\IPython\core\interactiveshell.py", line 3185, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\IPython\core\interactiveshell.py", line 3377, in run_ast_nodes
      if (await self.run_code(code, result,  async_=asy)):
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\IPython\core\interactiveshell.py", line 3457, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\Windows\AppData\Local\Temp\ipykernel_10008\204393382.py", line 12, in <module>
      history = model.fit(train_set,
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\training.py", line 1409, in fit
      tmp_logs = self.train_function(iterator)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\training.py", line 1051, in train_function
      return step_function(self, iterator)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\training.py", line 1040, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\training.py", line 1030, in run_step
      outputs = model.train_step(data)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\training.py", line 894, in train_step
      return self.compute_metrics(x, y, y_pred, sample_weight)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\training.py", line 987, in compute_metrics
      self.compiled_metrics.update_state(y, y_pred, sample_weight)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\compile_utils.py", line 501, in update_state
      metric_obj.update_state(y_t, y_p, sample_weight=mask)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\utils\metrics_utils.py", line 70, in decorated
      update_op = update_state_fn(*args, **kwargs)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\metrics\base_metric.py", line 140, in update_state_fn
      return ag_update_state(*args, **kwargs)
    File "C:\Users\Windows\AppData\Local\Temp\ipykernel_10008\1378111743.py", line 12, in update_state
      return super().update_state(y_true, y_pred, sample_weight)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\metrics\metrics.py", line 2494, in update_state
      current_cm = tf.math.confusion_matrix(
Node: 'confusion_matrix/assert_less/Assert/AssertGuard/Assert'
Detected at node 'confusion_matrix/assert_less/Assert/AssertGuard/Assert' defined at (most recent call last):
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\runpy.py", line 86, in _run_code
      exec(code, run_globals)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\spyder_kernels\console\__main__.py", line 24, in <module>
      start.main()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\spyder_kernels\console\start.py", line 332, in main
      kernel.start()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\kernelapp.py", line 677, in start
      self.io_loop.start()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\tornado\platform\asyncio.py", line 199, in start
      self.asyncio_loop.run_forever()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\asyncio\base_events.py", line 600, in run_forever
      self._run_once()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\asyncio\base_events.py", line 1896, in _run_once
      handle._run()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\asyncio\events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\kernelbase.py", line 471, in dispatch_queue
      await self.process_one()
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\kernelbase.py", line 460, in process_one
      await dispatch(*args)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\kernelbase.py", line 367, in dispatch_shell
      await result
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\kernelbase.py", line 662, in execute_request
      reply_content = await reply_content
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\ipkernel.py", line 360, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\ipykernel\zmqshell.py", line 532, in run_cell
      return super().run_cell(*args, **kwargs)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\IPython\core\interactiveshell.py", line 2914, in run_cell
      result = self._run_cell(
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\IPython\core\interactiveshell.py", line 2960, in _run_cell
      return runner(coro)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\IPython\core\async_helpers.py", line 78, in _pseudo_sync_runner
      coro.send(None)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\IPython\core\interactiveshell.py", line 3185, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\IPython\core\interactiveshell.py", line 3377, in run_ast_nodes
      if (await self.run_code(code, result,  async_=asy)):
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\IPython\core\interactiveshell.py", line 3457, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\Windows\AppData\Local\Temp\ipykernel_10008\204393382.py", line 12, in <module>
      history = model.fit(train_set,
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\training.py", line 1409, in fit
      tmp_logs = self.train_function(iterator)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\training.py", line 1051, in train_function
      return step_function(self, iterator)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\training.py", line 1040, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\training.py", line 1030, in run_step
      outputs = model.train_step(data)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\training.py", line 894, in train_step
      return self.compute_metrics(x, y, y_pred, sample_weight)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\training.py", line 987, in compute_metrics
      self.compiled_metrics.update_state(y, y_pred, sample_weight)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\engine\compile_utils.py", line 501, in update_state
      metric_obj.update_state(y_t, y_p, sample_weight=mask)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\utils\metrics_utils.py", line 70, in decorated
      update_op = update_state_fn(*args, **kwargs)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\metrics\base_metric.py", line 140, in update_state_fn
      return ag_update_state(*args, **kwargs)
    File "C:\Users\Windows\AppData\Local\Temp\ipykernel_10008\1378111743.py", line 12, in update_state
      return super().update_state(y_true, y_pred, sample_weight)
    File "C:\Users\Windows\anaconda3\envs\Kal\lib\site-packages\keras\metrics\metrics.py", line 2494, in update_state
      current_cm = tf.math.confusion_matrix(
Node: 'confusion_matrix/assert_less/Assert/AssertGuard/Assert'
2 root error(s) found.
  (0) INVALID_ARGUMENT:  assertion failed: [`labels` out of bound] [Condition x < y did not hold element-wise:] [x (confusion_matrix/control_dependency:0) = ] [6 6 6...] [y (confusion_matrix/Cast_2:0) = ] [4]
     [[{{node confusion_matrix/assert_less/Assert/AssertGuard/Assert}}]]
     [[confusion_matrix/assert_less_1/Assert/AssertGuard/pivot_f/_31/_61]]
  (1) INVALID_ARGUMENT:  assertion failed: [`labels` out of bound] [Condition x < y did not hold element-wise:] [x (confusion_matrix/control_dependency:0) = ] [6 6 6...] [y (confusion_matrix/Cast_2:0) = ] [4]
     [[{{node confusion_matrix/assert_less/Assert/AssertGuard/Assert}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_17948]

I will be very glad if anyone can help with the debug.


Solution

  • You can change your label by using

    dataset = dataset.map(lambda x, y: (x, y/2-1), num_parallel_calls=tf.data.AUTOTUNE)
    

    inside the function create_dataset_tf. It will map the label from 2, 4, 6, 8 to 0, 1, 2, 3.