Tags: python, tensorflow, machine-learning, keras, tensorflow2.0

Keras model.export() fails because of NoneType shapes in model


I am trying to fine-tune a DeepLabV3Plus model from the keras_cv library on a custom dataset, but exporting it to the SavedModel format fails with this error:

File "C:\Users\u\.pyenv\pyenv-win\versions\3.10.2\lib\site-packages\keras\src\utils\traceback_utils.py", line 731, in error_handler  *
    return fn(*args, **kwargs)

TypeError: Exception encountered when calling UpSampling2D.call().

unsupported operand type(s) for *: 'NoneType' and 'int'

Arguments received by UpSampling2D.call():
  • inputs=tf.Tensor(shape=(None, None, None, 256), dtype=float32)

This is confusing, because I passed input_shape=[224,224,3] to DeepLabV3Plus.from_preset, yet model.summary() shows None shapes for all layers. From the notebooks I have seen, however, this appears to be the intended behavior even when an input shape is specified.
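The message itself is what plain Python prints when None is multiplied by an int, which suggests that somewhere a static height or width of None is being scaled by an integer factor. A minimal sketch of that failure mode follows; upsampled_size is a hypothetical helper written for illustration, not the actual Keras code:

```python
def upsampled_size(static_shape, factor=2):
    """Hypothetical stand-in for a layer that scales height and width by a Python int."""
    _, height, width, _ = static_shape
    return (height * factor, width * factor)


# With known spatial dimensions the arithmetic is fine.
print(upsampled_size((None, 56, 56, 256)))

# With unknown (None) spatial dimensions it fails like the traceback above.
try:
    upsampled_size((None, None, None, 256))
except TypeError as err:
    print(err)
```

Only the batch dimension can safely stay None in this sketch; the height and width have to be concrete integers before they are multiplied.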

As for the training script, this is the code I used:

import keras
import keras_cv

model = keras_cv.models.DeepLabV3Plus.from_preset(
    "mobilenet_v3_large_imagenet",
    num_classes=NUM_CLASSES,
    input_shape=[224,224,3],
    load_weights=True
)

# Freeze every layer except the last `layers_to_train` layers.
layers_to_train = 1
for layer in model.layers[:-layers_to_train]:
    layer.trainable = False
model.summary()


model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=[keras.losses.CategoricalFocalCrossentropy(from_logits=False, alpha=class_weights, gamma=3)],
    metrics=[keras.metrics.OneHotMeanIoU(num_classes=NUM_CLASSES), 'accuracy'])
 
callback_cyclic = CyclicLR(base_lr=LEARNING_RATE, max_lr=MAX_LEARNING_RATE, step_size=STEP_SIZE, mode=CYCLIC_MODE)

history = model.fit(train_dataset, epochs=NUM_EPOCHS, batch_size=NUM_BATCH, validation_data=val_dataset, callbacks=[callback_cyclic])


model.export(savepath_dir+"model.tf")

Solution

  • Solved by applying the workaround from this keras_cv GitHub issue, with the input size changed to match my own: https://github.com/keras-team/keras-cv/issues/2455
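The code from the linked issue is not reproduced here. As a rough sketch of one way to pin the input size at export time, assuming Keras 3 on the TensorFlow backend and its keras.export.ExportArchive API, the endpoint can be registered with an explicit input signature so that height and width are concrete during tracing. The helper names fixed_input_shape and export_with_fixed_shape are hypothetical, and this has not been verified against DeepLabV3Plus specifically:

```python
def fixed_input_shape(height, width, channels=3):
    """Return an input shape whose batch size is dynamic but whose spatial size is pinned."""
    return (None, height, width, channels)


def export_with_fixed_shape(model, export_dir, height=224, width=224):
    """Export `model` as a SavedModel with a fixed-size `serve` endpoint (assumed approach)."""
    # Heavy dependencies are imported here so the shape helper works without them.
    import tensorflow as tf
    import keras

    archive = keras.export.ExportArchive()
    archive.track(model)
    archive.add_endpoint(
        name="serve",
        # Run the model in inference mode while tracing.
        fn=lambda x: model.call(x, training=False),
        input_signature=[
            tf.TensorSpec(shape=fixed_input_shape(height, width), dtype=tf.float32)
        ],
    )
    archive.write_out(export_dir)


# Usage with the names from the question (not executed here):
# export_with_fixed_shape(model, savepath_dir + "model.tf", height=224, width=224)
```

The height and width passed in should match the size the model was trained on, which is [224,224,3] in the script above.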