
What is the analytic interpretation for a TensorFlow custom gradient?


The official tf.custom_gradient documentation shows how to define a custom gradient for log(1 + exp(x)):

import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad(dy):
    return dy * (1 - 1 / (1 + e))
  return tf.math.log(1 + e), grad

For y = log(1 + exp(x)), the analytic derivative is dy/dx = 1 - 1 / (1 + exp(x)), which is the logistic sigmoid of x.
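A quick numerical check of this derivative, as a sketch assuming TensorFlow 2 in eager mode (the sample points are arbitrary): TensorFlow's own autodiff of the plain expression, the hand-derived formula, and tf.sigmoid should all agree.

```python
import tensorflow as tf

# Arbitrary sample points at which to compare the derivatives.
x = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0])

# Let TensorFlow differentiate the plain expression (no custom gradient).
with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant, so it must be watched explicitly
    y = tf.math.log(1 + tf.exp(x))
autodiff = tape.gradient(y, x)

# The hand-derived formula, which is also the logistic sigmoid.
analytic = 1 - 1 / (1 + tf.exp(x))

print(autodiff.numpy())
print(analytic.numpy())
print(tf.sigmoid(x).numpy())
```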

However, in the code, grad returns dy * (1 - 1 / (1 + exp(x))). Reading this as dy/dx = dy * (1 - 1 / (1 + exp(x))) does not give a valid equation, and reading it as dx = dy * (1 - 1 / (1 + exp(x))) is also wrong, because the factor would then have to be the reciprocal.

What, then, does the grad function actually compute?


Solution

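  • I finally figured it out. The dy argument would be better named upstream_gradient or upstream_dy_dx: it is not an infinitesimal, but the gradient already accumulated from the operations between this function and the final output.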

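    By the chain rule, for a chain of functions x[i] = f_i(x[i+1]) running from the input x[n] = x to the output x[0] = y, we know that

    $$\frac{dy}{dx} = \frac{dx_0}{dx_1}\cdot\frac{dx_1}{dx_2}\cdots\frac{dx_{n-1}}{dx_n} = \prod_{i=0}^{n-1}\frac{dx_i}{dx_{i+1}}$$

    where dx[i]/dx[i+1] is the gradient of the current function.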
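As a worked instance of this product, take the example below, where foo(t) = t² is applied three times to x. Label the intermediates x[3] = x, x[2] = x², x[1] = x⁴ and x[0] = y = x⁸, so each local factor is dx[i]/dx[i+1] = 2·x[i+1]:

$$\frac{dy}{dx} = \frac{dx_0}{dx_1}\cdot\frac{dx_1}{dx_2}\cdot\frac{dx_2}{dx_3} = (2x_1)(2x_2)(2x_3) = (2x^4)(2x^2)(2x) = 8x^7$$

At x = 2 the three factors are 2·16 = 32, 2·4 = 8 and 2·2 = 4, and 32 · 8 · 4 = 1024 = 8 · 2⁷.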

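    So dy is the product of all the gradients upstream of this function, i.e. those of the functions between it and the final output. grad must multiply dy by the current function's own gradient and return the result, which then becomes the dy of the next function downstream.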
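A minimal sketch that makes dy visible, assuming eager execution: it reuses the documentation's log1pexp, adds a hypothetical received list to record the dy that grad is handed, and multiplies the output by 3.0 so that the upstream gradient is not simply 1.

```python
import tensorflow as tf

received = []  # hypothetical helper: records the dy passed into grad

@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
        # dy = d(final output) / d(this function's output)
        received.append(float(dy))
        # upstream gradient times the local derivative
        return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad

x = tf.constant(0.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    z = 3.0 * log1pexp(x)  # this op sits upstream of log1pexp in the backward pass

g = tape.gradient(z, x)
print(received)   # [3.0]
print(float(g))   # 1.5
```

Here grad receives dy = 3.0, the gradient of z with respect to log1pexp's output, and returns 3.0 × sigmoid(0) = 1.5.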


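    So, if you forget to multiply by dy, the upstream gradients are silently discarded and the chain is cut at this function. This is similar in spirit to tf.stop_gradient but not identical: tf.stop_gradient blocks the gradient entirely (it becomes zero), whereas a grad that ignores dy still passes its own local gradient downstream.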
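A minimal sketch of that difference, assuming eager execution; square_ok, square_forgot_dy and grad_of are hypothetical names used only for illustration. It compares a correct grad, one that forgets dy, and tf.stop_gradient on the same x ** 8 chain at x = 2.0. grad_of passes unconnected_gradients=tf.UnconnectedGradients.ZERO so that a blocked path reports 0.0 instead of None.

```python
import tensorflow as tf

@tf.custom_gradient
def square_ok(x):
    def grad(dy):
        return dy * 2 * x  # chain rule: upstream gradient times local gradient
    return x ** 2, grad

@tf.custom_gradient
def square_forgot_dy(x):
    def grad(dy):
        return 2 * x  # bug: the upstream gradient dy is ignored
    return x ** 2, grad

def grad_of(f, x0):
    x = tf.constant(x0)
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = f(x)
    # ZERO: report 0.0 rather than None when no gradient reaches x
    return tape.gradient(y, x, unconnected_gradients=tf.UnconnectedGradients.ZERO)

x0 = 2.0
ok = grad_of(lambda x: square_ok(square_ok(square_ok(x))), x0)
forgot = grad_of(lambda x: square_forgot_dy(square_forgot_dy(square_forgot_dy(x))), x0)
stopped = grad_of(lambda x: tf.stop_gradient(x) ** 8, x0)

print(float(ok))       # 1024.0
print(float(forgot))   # 4.0
print(float(stopped))  # 0.0
```

With the correct grad the result is 8 · 2⁷ = 1024; the buggy version returns only the innermost local factor 4, and the stop_gradient version returns 0.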

    Here is code which demos this. A full notebook is here.

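    import tensorflow as tf

    @tf.custom_gradient
    def foo(x):
        tf.debugging.assert_rank(x, 0)  # keep the demo to scalars

        def grad(dy_dx_upstream):
            # Gradient of this function alone: d(x ** 2)/dx
            dy_dx = 2 * x
            # Chain rule: scale the local gradient by everything upstream
            dy_dx_downstream = dy_dx * dy_dx_upstream
            tf.print(f'x={x}\tupstream={dy_dx_upstream}\tcurrent={dy_dx}\t\tdownstream={dy_dx_downstream}')
            return dy_dx_downstream

        y = x ** 2
        # The f-string is built by Python before tf.print sees it, so values show only in eager mode
        tf.print(f'x={x}\ty={y}')

        return y, grad


    x = tf.constant(2.0, dtype=tf.float32)

    with tf.GradientTape(persistent=True) as tape:
        tape.watch(x)  # x is a constant, so the tape must be told to track it
        y = foo(foo(foo(x)))  # y = x ** 8

    tf.print(f'\nfinal dy/dx={tape.gradient(y, x)}')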
    

    Output

    x=2.0    y=4.0
    x=4.0    y=16.0
    x=16.0    y=256.0
    x=16.0    upstream=1.0    current=32.0        downstream=32.0
    x=4.0    upstream=32.0    current=8.0        downstream=256.0
    x=2.0    upstream=256.0    current=4.0        downstream=1024.0
    
    final dy/dx=1024.0
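The upstream=1.0 in the first backward line is the default seed dy/dy = 1 at the top of the chain. As a sketch, using a quieter hypothetical variant of foo that records the upstream values in a list instead of printing, GradientTape.gradient's output_gradients argument replaces that seed, and every upstream value, as well as the final gradient, scales with it:

```python
import tensorflow as tf

seen = []  # hypothetical helper: upstream gradient received by each grad call

@tf.custom_gradient
def foo(x):
    def grad(upstream):
        seen.append(float(upstream))
        return upstream * 2 * x  # upstream times d(x ** 2)/dx
    return x ** 2, grad

x = tf.constant(2.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = foo(foo(foo(x)))  # y = x ** 8

# output_gradients replaces the default seed of 1.0 at the top of the chain
g = tape.gradient(y, x, output_gradients=tf.constant(10.0))
print(seen)      # [10.0, 320.0, 2560.0]
print(float(g))  # 10240.0
```

The backward pass visits the outermost call first, exactly as in the output above, so the recorded values are 10 × 1, 10 × 32 and 10 × 256.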