
Keras predict/predict_on_batch giving different answers than predict_step/__call__()


From my understanding, these four methods (predict, predict_on_batch, predict_step, and a direct forward pass through the model, e.g. model(x, training=False) or __call__()) should all give the same results; some are just more efficient than others at handling a batch of data versus a single sample.
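For reference, here is a rough sketch of how the four entry points relate in TF 2.x Keras with the default run_eagerly=False: predict_step essentially calls self(x, training=False), while predict and predict_on_batch run that step inside a tf.function (graph mode) and, for predict, loop over batches. The helper simplified_predict is a hypothetical approximation of that structure, not Keras's actual implementation, and the one-layer toy model is only a stand-in.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def simplified_predict(model, x, batch_size=32):
    """Hypothetical approximation of Model.predict: graph-compiled predict_step over batches."""
    graph_step = tf.function(model.predict_step)  # predict/predict_on_batch compile the step
    batches = tf.data.Dataset.from_tensor_slices(x).batch(batch_size)
    return np.concatenate([graph_step(b).numpy() for b in batches], axis=0)


# Toy stand-in model: a single linear Conv2D
inputs = keras.Input(shape=(8, 8, 3))
outputs = layers.Conv2D(4, 3, padding="same")(inputs)
model = keras.Model(inputs, outputs)

x = np.random.rand(5, 8, 8, 3).astype("float32")
eager_out = model(x, training=False).numpy()   # __call__, runs eagerly
step_out = model.predict_step(x).numpy()       # calling predict_step directly is also eager
graph_out = simplified_predict(model, x, batch_size=2)
keras_out = model.predict(x, verbose=0)

print(np.abs(eager_out - graph_out).max(), np.abs(keras_out - graph_out).max())
```

On a CPU these paths are expected to agree to within floating-point tolerance, which is the behavior the question assumes.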

But I am actually getting different results on an image super-resolution (upscaling) task I'm working on:

for lowres, _ in val.take(1):
    # Get a randomly cropped region of the lowres image for upscaling
    lowres = tf.image.random_crop(lowres, (150, 150, 3))  # uint8
    
    # Need to add a dummy batch dimension for the predict step    
    model_inputs = tf.expand_dims(lowres, axis=0)  # (1, 150, 150, 3), uint8
    
    # And convert the uint8 image values to float32 for input to the model
    model_inputs = tf.cast(model_inputs, tf.float32)  # float32
    
    preds = model.predict_on_batch(model_inputs)
    min_val = tf.reduce_min(preds).numpy()
    max_val = tf.reduce_max(preds).numpy()
    print("Min value: ", min_val)
    print("Max value: ", max_val)
    
    preds = model.predict(model_inputs)
    min_val = tf.reduce_min(preds).numpy()
    max_val = tf.reduce_max(preds).numpy()
    print("Min value: ", min_val)
    print("Max value: ", max_val)
    
    preds = model.predict_step(model_inputs)
    min_val = tf.reduce_min(preds).numpy()
    max_val = tf.reduce_max(preds).numpy()
    print("Min value: ", min_val)
    print("Max value: ", max_val)
    
    preds = model(model_inputs, training=False)  # __call__()
    min_val = tf.reduce_min(preds).numpy()
    max_val = tf.reduce_max(preds).numpy()
    print("Min value: ", min_val)
    print("Max value: ", max_val)
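The repeated min/max blocks above can be folded into a single helper that runs all four paths on the same batch and also reports the largest absolute difference from the eager __call__ output, which pinpoints a mismatch more directly than comparing value ranges. compare_prediction_paths is a hypothetical name; the sketch assumes a built Keras model and a float32 batch such as model_inputs, and uses a toy model only to show usage.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def compare_prediction_paths(model, batch):
    """Run the four inference entry points; report min, max and max |diff| vs. __call__."""
    outputs = {
        "predict": model.predict(batch, verbose=0),
        "predict_on_batch": model.predict_on_batch(batch),
        "predict_step": model.predict_step(batch).numpy(),
        "__call__": model(batch, training=False).numpy(),
    }
    reference = outputs["__call__"]
    report = {}
    for name, out in outputs.items():
        out = np.asarray(out)
        report[name] = (out.min(), out.max(), np.abs(out - reference).max())
        print(f"{name:>16}: min={out.min():.4f} max={out.max():.4f} "
              f"max|diff vs __call__|={report[name][2]:.6f}")
    return report


# Usage with a toy model; in the question this would be the trained model and model_inputs
toy = keras.Sequential([keras.Input(shape=(16, 16, 3)),
                        keras.layers.Conv2D(3, 3, padding="same")])
batch = tf.random.uniform((1, 16, 16, 3), maxval=255.0)
report = compare_prediction_paths(toy, batch)
```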
    

Prints (in order: predict_on_batch, predict, predict_step, __call__):

Min value:  -6003.622
Max value:  5802.6826

Min value:  -6003.622
Max value:  5802.6826

Min value:  -53.7696
Max value:  315.1499

Min value:  -53.7696
Max value:  315.1499

Both predict_step and __call__() give the "correct" answers, in the sense that the upscaled images they produce look correct.

I'm happy to share more details on the model if that's helpful, but for now I'll leave it at this to avoid overcomplicating the question. At first I wondered whether these methods give different results because of training/inference modes, but my model doesn't use any BatchNorm or Dropout layers, so that shouldn't make a difference here. It is composed entirely of Conv2D, Add, tf.nn.depth_to_space (pixel shuffle), and Rescaling layers. That's it. It also doesn't use subclassing or override any methods; it is built with the functional API via keras.Model(inputs, outputs).
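Since the model has no BatchNorm or Dropout, the training flag should not change its output, and that is easy to confirm directly by comparing training=True against training=False. The toy architecture below only mimics the layer types listed (Conv2D, Add, depth_to_space, Rescaling) and is an assumption, not the actual model; depth_to_space is wrapped in a Lambda layer here for portability across Keras versions.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Toy model built only from the layer types named in the question
inp = keras.Input(shape=(None, None, 3))
x = layers.Rescaling(1.0 / 255)(inp)
x = layers.Conv2D(12, 3, padding="same")(x)
skip = x
x = layers.Conv2D(12, 3, padding="same")(x)
x = layers.Add()([skip, x])
# Pixel shuffle: 12 channels -> 3 channels at 2x spatial resolution
x = layers.Lambda(lambda t: tf.nn.depth_to_space(t, block_size=2))(x)
out = layers.Rescaling(255.0)(x)
model = keras.Model(inp, out)

batch = tf.random.uniform((1, 20, 20, 3), maxval=255.0)
diff = np.abs(model(batch, training=True).numpy()
              - model(batch, training=False).numpy()).max()
print("max |training=True - training=False|:", diff)
```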

Any ideas why these prediction methods would give different answers?

EDIT: I've been able to create a minimal reproducible example that shows the issue. Please see: https://www.kaggle.com/code/quackaddict7/really-minimum-reproducible-example

I initially couldn't reproduce the problem in a minimal example, so I gradually added back the dataset, batching, data augmentation, training, and model saving/restoring, and finally discovered that the issue is GPU vs. CPU! So I removed all of that again for the minimal example. If you run the attached notebook on CPU, all four methods give the same answer with randomly initialized weights. But if you switch to a P100 GPU, predict/predict_on_batch differ from predict_step/the forward pass (__call__).

So I guess at this point, my question is, why are CPU vs. GPU results different here?
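One way to narrow this down is to separate the two variables involved: device (CPU vs. GPU) and execution mode (eager vs. graph). In TF 2.x Keras, predict and predict_on_batch run the step inside a tf.function by default, while calling predict_step or __call__ directly runs eagerly, which lines up with the 2 + 2 split in the printed results. The sketch below runs the same forward pass in every device/mode combination using tf.device and a freshly created tf.function; the helper name forward and the relu stand-in model are assumptions for illustration. If only the GPU + graph combination disagrees, the problem most likely sits in the graph-compiled GPU path rather than in Keras's batching logic.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def forward(model, batch, device, graph_mode):
    """Run model(batch, training=False) on `device`, eagerly or inside a fresh tf.function."""
    def call(b):
        return model(b, training=False)

    with tf.device(device):
        fn = tf.function(call) if graph_mode else call  # graph mode mirrors predict()
        return fn(batch).numpy()


# Stand-in model using a relu Conv2D, the pattern implicated in the answer
inp = keras.Input(shape=(None, None, 3))
h = layers.Conv2D(16, 3, padding="same", activation="relu")(inp)
model = keras.Model(inp, layers.Conv2D(3, 3, padding="same")(h))
batch = tf.random.uniform((1, 48, 48, 3), maxval=255.0)

devices = ["/CPU:0"] + (["/GPU:0"] if tf.config.list_physical_devices("GPU") else [])
reference = forward(model, batch, "/CPU:0", graph_mode=False)
diffs = {}
for device in devices:
    for graph_mode in (False, True):
        result = forward(model, batch, device, graph_mode)
        diffs[(device, graph_mode)] = float(np.abs(result - reference).max())
        print(f"{device} graph_mode={graph_mode}: "
              f"max |diff| vs CPU eager = {diffs[(device, graph_mode)]:.6f}")
```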


Solution

  • I tested the given sample code with tf.keras==2.12.0 and found what looks like a bug in the API that appears only on GPU. In your sample code, the mismatch is caused by the relu activation. With any other activation, e.g. selu, elu, or even leaky_relu, all four methods agree as expected.

    def ResBlock(inputs):
        # The activation="relu" here is what triggers the GPU-only mismatch
        x = layers.Conv2D(64, 3, padding="same", activation="relu")(inputs)
        x = layers.Conv2D(64, 3, padding="same")(x)
        x = layers.Add()([inputs, x])
        return x
    

    To keep using ReLU, the following workaround can be adopted for now: compute it with keras.backend.maximum inside a Lambda layer instead of passing activation="relu" to Conv2D.

    def relu(x):
        # Same math as ReLU (elementwise max(x, 0)), applied via a Lambda layer
        return keras.backend.maximum(x, 0)
    
    def ResBlock(inputs):
        x = layers.Conv2D(64, 3, padding="same")(inputs)  # no built-in activation
        x = keras.layers.Lambda(relu, output_shape=lambda shape: shape)(x)
        x = layers.Conv2D(64, 3, padding="same")(x)
        x = layers.Add()([inputs, x])
        return x
    

    Here is the full code for reference.

    Model

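    import numpy as np
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    
    def relu(x):
        return keras.backend.maximum(x, 0)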
    
    def ResBlock(inputs):
        x = layers.Conv2D(64, 3, padding="same")(inputs)
        x = keras.layers.Lambda(
            relu, output_shape=lambda shape: shape
        )(x)
        x = layers.Conv2D(64, 3, padding="same")(x)
        x = layers.Add()([inputs, x])
        return x
    
    def Upsampling(inputs, factor=2, **kwargs):
        x = layers.Conv2D(
            64 * (factor ** 2), 3, padding="same", **kwargs
        )(inputs)
        x = tf.nn.depth_to_space(x, block_size=factor)
        x = layers.Conv2D(
            64 * (factor ** 2), 3, padding="same", **kwargs
        )(x)
        x = tf.nn.depth_to_space(x, block_size=factor)
        return x
    
    def make_model(num_filters, num_of_residual_blocks):
        input_layer = layers.Input(shape=(None, None, 3))
        x = layers.Rescaling(scale=1.0 / 255)(input_layer)
        x = x_new = layers.Conv2D(num_filters, 3, padding="same")(x)
    
        for i in range(num_of_residual_blocks):
            x_new = ResBlock(x_new)
    
        x_new = layers.Conv2D(num_filters, 3, padding="same")(x_new)
        x = layers.Add()([x, x_new])
        x = Upsampling(x)
        x = layers.Conv2D(3, 3, padding="same")(x)
        
        output_layer = layers.Rescaling(scale=255)(x)
        return keras.Model(input_layer, output_layer)
    

    Inference

    # num_filters must be 64 because ResBlock adds its input to a hard-coded 64-filter Conv2D;
    # the notebook's residual-block count isn't shown here, so 16 is an assumed value
    model = make_model(num_filters=64, num_of_residual_blocks=16)
    
    lowres = tf.random.uniform(
        shape=(150, 150, 3),
        minval=0,
        maxval=256,
        dtype='float32'
    )
    model_inputs = tf.expand_dims(lowres, axis=0)  # (1, 150, 150, 3)
    
    predict_out = model.predict(model_inputs)
    predict_on_batch_out = model.predict_on_batch(model_inputs)
    predict_call_out = model(model_inputs, training=False).numpy()
    predict_step_out = model.predict_step(model_inputs).numpy()
    print(
        predict_out.shape,
        predict_on_batch_out.shape,
        predict_call_out.shape,
        predict_step_out.shape
    )
    # Output:
    # 1/1 [==============================] - 1s 1s/step
    # (1, 600, 600, 3) (1, 600, 600, 3) (1, 600, 600, 3) (1, 600, 600, 3)
    

    Output Checking

    # OK
    np.testing.assert_allclose(
        predict_out,
        predict_on_batch_out,
        rtol=1e-5, atol=1e-5
    )
    
    # OK
    np.testing.assert_allclose(
        predict_on_batch_out,
        predict_call_out,
        rtol=1e-5, atol=1e-5
    )
    
    # OK
    np.testing.assert_allclose(
        predict_call_out,
        predict_step_out,
        rtol=1e-5, atol=1e-5
    )
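As a sanity check that the workaround does not change the model's math, the sketch below builds the ResBlock twice (once with the built-in relu activation and once with the Lambda-based relu), copies the weights across, and compares outputs. It uses tf.maximum directly, which is what keras.backend.maximum wraps in tf.keras 2.x, so that it also runs where keras.backend.maximum is unavailable (Keras 3); that substitution, the helper names res_block and build, and the random input are assumptions for illustration. On CPU both versions should produce identical outputs; this does not by itself reproduce or rule out the GPU issue.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def relu(x):
    # Variant of the answer's workaround: elementwise max(x, 0) via tf.maximum
    return tf.maximum(x, 0.0)


def res_block(inputs, use_lambda_relu):
    """ResBlock from the answer, with either the built-in relu or the Lambda workaround."""
    if use_lambda_relu:
        x = layers.Conv2D(64, 3, padding="same")(inputs)
        x = layers.Lambda(relu, output_shape=lambda shape: shape)(x)
    else:
        x = layers.Conv2D(64, 3, padding="same", activation="relu")(inputs)
    x = layers.Conv2D(64, 3, padding="same")(x)
    return layers.Add()([inputs, x])


def build(use_lambda_relu):
    inp = keras.Input(shape=(None, None, 64))
    return keras.Model(inp, res_block(inp, use_lambda_relu))


builtin = build(use_lambda_relu=False)
workaround = build(use_lambda_relu=True)
# Same kernels and biases in both; the Lambda layer has no weights of its own
workaround.set_weights(builtin.get_weights())

batch = tf.random.normal((2, 24, 24, 64))
max_diff = np.abs(builtin(batch, training=False).numpy()
                  - workaround(batch, training=False).numpy()).max()
print("max |built-in relu - Lambda relu|:", max_diff)
```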