python keras pytorch loss-function keras-3

Keras 3 Custom Loss Function to mask NaN


I am trying to build a custom loss function in Keras 3 that can be used with either the jax or the torch backend. I want to mask out of y_pred and y_true all the indices where y_true equals a certain value, and pass the remaining values to a given loss function.

But every time I try to fit a model with my loss function, with either the jax or the torch backend, it breaks. The errors essentially say that it cannot take the indices or apply the mask, because doing so would require access to the concrete values in the tensor.

I have tried two approaches:



import keras
from keras import Loss, ops



class NanValueLossA(Loss):
    def __init__(
        self,
        loss_to_use=None,
        nan_value=None,
        name="nan_value_loss",
        **kwargs,
    ):
        self.nan_value = nan_value
        self.loss_to_use=loss_to_use
        super().__init__(name=name, **kwargs)

    def call(self, y_true, y_pred):

        valid_mask = ops.not_equal(y_true, self.nan_value)
        return self.loss_to_use(y_true[valid_mask], y_pred[valid_mask])
    


class NanValueLossB(Loss):
    def __init__(
        self,
        loss_to_use=None,
        nan_value=None,
        name="nan_value_loss",
        **kwargs,
    ):
        self.nan_value = nan_value
        self.loss_to_use=loss_to_use
        super().__init__(name=name, **kwargs)

    def call(self, y_true, y_pred):

        valid_mask = ops.not_equal(y_true, self.nan_value)
        valid_indices = ops.where(valid_mask)
        masked_y_pred = ops.take(y_pred,valid_indices)
        masked_y_true = ops.take(y_true,valid_indices)

        return self.loss_to_use(masked_y_true, masked_y_pred)

I have tried both versions with jax and with torch. I have also tried a couple of other ways, but the problem is the same every time. Here are the errors:

NanValueLossA: torch:

  File "c:\....\Lib\site-packages\keras\src\backend\torch\core.py", line 162, in convert_to_tensor
    x = x.to(device)
        ^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!

jax:

  File "c:....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 6976, in _expand_bool_indices
    raise errors.NonConcreteBooleanIndexError(abstract_i)
jax.errors.NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[32,1,128,128,1])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

NanValueLossB: torch:

  File "c:\....\Lib\site-packages\keras\src\backend\torch\core.py", line 162, in convert_to_tensor
    x = x.to(device)
        ^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!

jax:

  File "C:....\advanced_losses.py", line 651, in call
    valid_indices = ops.where(valid_mask)
                    ^^^^^^^^^^^^^^^^^^^^^
  File "....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 1946, in where
    return nonzero(condition, size=size, fill_value=fill_value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 2378, in nonzero
    calculated_size = core.concrete_dim_or_error(calculated_size,
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function wrapped_fn at c:.....\Lib\site-packages\keras\src\backend\jax\core.py:153 for jit. This concrete value was not available in Python because it depends on the value of the argument args[1].

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Before Keras 3, I used a TensorFlow-based loss function and it worked, but now I want something that works with torch. This was my TensorFlow implementation:

import numpy as np
import tensorflow as tf


def nan_mean_squared_error_loss(nan_value=np.nan):
    # Create a loss function
    def loss(y_true, y_pred):
        indices = tf.where(tf.not_equal(y_true, nan_value))
        return tf.keras.losses.mean_squared_error(
            tf.gather_nd(y_true, indices), tf.gather_nd(y_pred, indices)
        )

    # Return a function
    return loss
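
One thing to watch in the TensorFlow version above: with the default nan_value=np.nan, the comparison cannot mask anything, because NaN is not equal to any value, itself included. A small NumPy illustration of the difference (the array is made up for the example):

```python
import numpy as np

y_true = np.array([1.0, np.nan, 3.0])

# NaN never compares equal to anything, including itself,
# so "not equal to NaN" is True everywhere and masks nothing.
mask_by_comparison = np.not_equal(y_true, np.nan)

# isnan is the reliable way to locate NaN entries.
mask_by_isnan = ~np.isnan(y_true)
```

A sentinel that is an ordinary number (for example -1.0) does not have this problem; only NaN itself needs the isnan test.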

Solution

  • I did not have tensorflow installed in the environment I was using. I had:

    keras 3.4.1, torch 2.3.1, torchaudio 2.3.1, torchvision 0.18.1

    I then installed tensorflow to test with it, and now running with the torch backend works. My guess is that some backend functionality Keras uses with torch depends on tensorflow being installed.

    tensorflow 2.16.2

    NanValueLossA is now working!