python, tensorflow, keras, neural-network, activation-function

Different activation function based on input


I am trying to build a Keras neural network in which the activation function of the output layer conditionally depends on the inputs. The actual activation function is quite complicated, so as a simpler example, consider the following:

import tensorflow as tf

def myactivation(y, x1, x2):
    # Zero out the activation when x1 exceeds the threshold;
    # otherwise cap y at x2.
    if x1 > 1.2:
        return tf.constant(0.0)
    else:
        return tf.math.minimum(y, x2)

The network consists of an input layer (i.e., x1 and x2), a hidden layer whose output is y, and an output layer. The activation above is the activation of the output layer.

How would I implement something like this? Grateful for any help and guidance you might have!


Solution

  • You can write a custom layer as shown in the Keras docs. For your application, the key is to use tf.where with your condition x1 > 1.2 to switch between the two activation types. To organize the inputs to the layer, I stitch them together with tf.stack so that the resulting tensor inp has a feature axis (you'll see this in the example).

    import tensorflow as tf
    import keras
    
    # For reference: https://keras.io/guides/making_new_layers_and_models_via_subclassing/
    class AwesomeLayer(keras.layers.Layer):
        def __init__(self, threshold):  # Here you could also define the layer shape if needed
            super().__init__()
            # Here you could define (trainable) layer parameters
            self.threshold = tf.constant(threshold)
    
        def call(self, inputs):
            # Feature-axis convention: inputs[..., 0] = x_1, inputs[..., 1] = x_2, inputs[..., 2] = y
            return tf.where(
                inputs[..., 0] > self.threshold,                 # condition x_1 > threshold
                tf.constant(0.0),                                # then: activation is 0
                tf.math.minimum(inputs[..., 2], inputs[..., 1])  # else: min(y, x_2)
            )
    

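    Because tf.where and tf.math.minimum operate element-wise, the same layer also accepts a whole batch of stacked samples without any changes. A quick sketch (the batch values below are made up for illustration):

    awesome_layer = AwesomeLayer(1.2)
    
    # One row per sample, columns in the order (x_1, x_2, y)
    batch = tf.constant([
        [  2.0, -5.0, 1.0],  # x_1 > 1.2 -> 0.0
        [-10.0, -5.0, 1.0],  # x_1 < 1.2 -> min(1, -5) = -5.0
        [-10.0, 25.0, 1.0],  # x_1 < 1.2 -> min(1, 25) = 1.0
    ])
    print(awesome_layer(batch).numpy())  # [ 0. -5.  1.]
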
    Let's walk through the individual cases to check that it does what you want:

    awesome_layer = AwesomeLayer(1.2)
    
    # Input Example 1
    
    x_1 = tf.constant(2.) # Case x_1 > 1.2
    x_2 = tf.constant(-5.)
    y = tf.constant(1.)
    
    inp = tf.stack([x_1, x_2, y], axis=-1)
    print(inp.shape) # (3,)
    
    outp = awesome_layer(inp)
    print(outp.numpy()) # 0.0
    
    # Input Example 2
    
    x_1 = tf.constant(-10.) # Case x_1 < 1.2
    x_2 = tf.constant(-5.)
    y = tf.constant(1.)
    
    inp = tf.stack([x_1, x_2, y], axis=-1)
    print(inp.shape) # (3,)
    
    outp = awesome_layer(inp)
    
    print(outp.numpy()) # -5.0
    
    # Input Example 3
    
    x_1 = tf.constant(-10.) # Case x_1 < 1.2
    x_2 = tf.constant(25.)
    y = tf.constant(1.)
    
    inp = tf.stack([x_1, x_2, y], axis=-1)
    print(inp.shape) # (3,)
    
    outp = awesome_layer(inp)
    
    print(outp.numpy()) # 1.0
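
    Finally, if you want to plug the layer into a complete model, one option is to concatenate the raw inputs with the hidden layer's output so the custom layer receives (x_1, x_2, y) along the feature axis. This is a minimal sketch, assuming a single Dense unit as the hidden layer that produces y; adapt the hidden part to your actual architecture:

    inputs = keras.Input(shape=(2,))  # (x_1, x_2)
    y = keras.layers.Dense(1)(inputs)  # hidden layer producing y (assumption: one Dense unit)
    stacked = keras.layers.Concatenate(axis=-1)([inputs, y])  # feature axis: (x_1, x_2, y)
    outputs = AwesomeLayer(1.2)(stacked)  # conditional activation at the output
    model = keras.Model(inputs, outputs)
    
    print(model(tf.constant([[2.0, -5.0]])).numpy())  # x_1 > 1.2, so [0.] regardless of the Dense weights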