Tags: python, tensorflow, keras, hessian-matrix

Hessian matrix of a Keras model with tf.hessians


I want to compute the Hessian matrix of a Keras model w.r.t. its input in graph mode using tf.hessians. Here is a minimal example:

import tensorflow as tf
from tensorflow import keras

# Purely linear model: a single Dense(1) layer with no activation, so the
# output is affine in the input and its second derivative w.r.t. the input
# is identically zero.
model = keras.Sequential([
    keras.Input((10,)),  # 10 input features per sample
    keras.layers.Dense(1)
])
model.summary()

@tf.function
def get_grads(inputs):
    """Gradient of the summed model output with respect to ``inputs``.

    Runs as a ``tf.function`` (graph mode) so that ``tf.gradients`` can be
    used. Returns the list produced by ``tf.gradients`` (one entry, for
    ``inputs``).
    """
    total = tf.reduce_sum(model(inputs))
    grad_list = tf.gradients(total, inputs)
    return grad_list

@tf.function
def get_hessian(inputs):
    """Hessian of the summed model output with respect to ``inputs``.

    ``tf.hessians`` only works in graph mode, hence ``tf.function``. With the
    purely linear model defined above, this call raises
    ``ValueError: None values not supported.``
    """
    total = tf.reduce_sum(model(inputs))
    hessian_list = tf.hessians(total, inputs)
    return hessian_list

# Driver: a batch of 3 random inputs with 10 features each.
batch_size = 3
test_input = tf.random.uniform((batch_size, 10))
out = model(test_input) # works fine
grads = get_grads(test_input) # works fine
# The next call fails because the model is linear (see the Solution below).
hessian = get_hessian(test_input) # raises ValueError: None values not supported.

While the forward pass and the get_grads function work fine, the get_hessian function raises `ValueError: None values not supported.`

By contrast, in this example

@tf.function
def get_hessian_(inputs):
    """Hessian of ``sum(inputs**2)`` with respect to ``inputs``.

    The loss is quadratic, so the second derivative is a constant 2 on the
    diagonal. Unlike the linear Keras model, this second derivative stays
    connected to ``inputs``, so ``tf.hessians`` succeeds.
    """
    squared = inputs ** 2
    return tf.hessians(tf.reduce_sum(squared), inputs)

# Evaluate on a random length-3 vector; the expected Hessian is 2 * identity.
get_hessian_(tf.random.uniform((3,)))[0]
# <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
# array([[2., 0., 0.],
#        [0., 2., 0.],
#        [0., 0., 2.]], dtype=float32)>

tf.hessians yields the expected result without error.


Solution

  • In your code example,

    you are trying to compute the Hessian of f(x) (the model output) w.r.t. x (the inputs), and f is linear because the model is a single Dense layer with no activation.

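    The Hessian of f(x) w.r.t. x should mathematically be a zero tensor, but tf.hessians cannot handle that case, which results in the error. tf.hessians differentiates each element of the first-order gradient a second time. For a linear model that gradient does not depend on x, so the second differentiation returns None. tf.hessians then fails when it tries to store that None, raising the ValueError. Adding a hidden layer with a non-linear activation makes the second derivative depend on x and eliminates the error.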

    Code examples:

    Using tf.hessians to get the Hessian:

    model = tf.keras.Sequential([
        # Non-linear hidden layer: remove it and the model becomes linear,
        # so tf.hessians raises "ValueError: None values not supported."
        tf.keras.layers.Dense(10, activation='sigmoid'),
        tf.keras.layers.Dense(1)
    ])
    @tf.function
    def get_hessian(inputs):
        """Return tf.hessians of the summed model output w.r.t. ``inputs``.

        tf.hessians only works in graph mode, hence @tf.function. The result
        is a one-element list whose tensor has shape inputs.shape + inputs.shape.
        """
        loss = tf.reduce_sum(model(inputs))
        return tf.hessians(loss, inputs)
    
    batch_size = 3
    tf.random.set_seed(123)
    test_input = tf.random.uniform((batch_size, 10), minval=1.5, maxval=2.5)
    hessian = get_hessian(test_input)
    print(type(hessian))
    print(len(hessian))
    print(hessian[0].shape)
    print(hessian[0][0,0,0,0])
    print(hessian[0][0,0,0,1])
    '''
    <class 'list'>
    1
    (3, 10, 3, 10)
    tf.Tensor(0.0028595054, shape=(), dtype=float32)
    tf.Tensor(0.0009458237, shape=(), dtype=float32)
    ''' 
    

    Using tf.GradientTape() to get the Hessian:

    model = tf.keras.Sequential([
        # Non-linear hidden layer: remove it and the second-order gradient is
        # unconnected, so get_hessian returns None.
        tf.keras.layers.Dense(10, activation='sigmoid'),
        tf.keras.layers.Dense(1)
    ])
    @tf.function
    def get_hessian(inputs):
        """Hessian of the summed model output w.r.t. ``inputs`` via nested tapes.

        The inner tape produces the first-order gradient. The outer tape
        records that computation so its jacobian gives the second derivatives.
        Returns None if the gradient does not depend on ``inputs``.
        """
        with tf.GradientTape() as t2:
            t2.watch(inputs)
            with tf.GradientTape() as t1:
                t1.watch(inputs)
                loss = tf.reduce_sum(model(inputs))
            # Computed inside t2's context so t2 can differentiate it again.
            g = t1.gradient(loss, inputs)
        return t2.jacobian(g, inputs)
    
    batch_size = 3
    tf.random.set_seed(123)
    test_input = tf.random.uniform((batch_size, 10), minval=1.5, maxval=2.5)
    hessian = get_hessian(test_input)
    print(type(hessian))
    print(hessian.shape if hessian is not None else None)
    print(hessian[0,0,0,0] if hessian is not None else None)
    print(hessian[0,0,0,1] if hessian is not None else None)
    '''
    <class 'tensorflow.python.framework.ops.EagerTensor'>
    (3, 10, 3, 10)
    tf.Tensor(0.0028595058, shape=(), dtype=float32)
    tf.Tensor(0.0009458238, shape=(), dtype=float32)
    '''
    

    If you want a zero tensor instead of an error or None, pass unconnected_gradients=tf.UnconnectedGradients.ZERO to both the gradient and the jacobian call:

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(1)  # purely linear model: true Hessian is zero
    ])
    @tf.function
    def get_hessian(inputs):
        """Hessian of the summed model output w.r.t. ``inputs``, zero-filled.

        unconnected_gradients=ZERO makes both tape calls return zeros instead
        of None when the target does not depend on ``inputs``.
        """
        with tf.GradientTape() as t2:
            t2.watch(inputs)
            with tf.GradientTape() as t1:
                t1.watch(inputs)
                loss = tf.reduce_sum(model(inputs))
            g = t1.gradient(loss, inputs, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        return t2.jacobian(g, inputs, unconnected_gradients=tf.UnconnectedGradients.ZERO)
    
    batch_size = 3
    tf.random.set_seed(123)
    test_input = tf.random.uniform((batch_size, 10), minval=1.5, maxval=2.5)
    hessian = get_hessian(test_input)
    print(type(hessian))
    print(hessian.shape)
    print(tf.math.count_nonzero(hessian))
    '''
    <class 'tensorflow.python.framework.ops.EagerTensor'>
    (3, 10, 3, 10)
    tf.Tensor(0, shape=(), dtype=int64)
    '''