The official tf.custom_gradient documentation shows how to define a custom gradient for log(1 + exp(x)):
@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
        return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad
When y = log(1 + exp(x)), analytically the derivative comes out to be dy/dx = (1 - 1 / (1 + exp(x))).
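This can be checked numerically; a quick sketch of my own (not from the docs), assuming TensorFlow 2.x eager execution, with an arbitrary input of 1.5:

import tensorflow as tf

x = tf.constant(1.5)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.math.log(1 + tf.exp(x))   # same forward computation, no custom gradient

autodiff = tape.gradient(y, x)       # what TF computes by backprop
analytic = 1 - 1 / (1 + tf.exp(x))   # the derivative derived above
tf.print(autodiff, analytic)         # both print roughly 0.817574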
However, in the code, the grad function returns dy * (1 - 1 / (1 + exp(x))).
Reading that as dy/dx = dy * (1 - 1 / (1 + exp(x))) is not a valid equation, and reading it as dx = dy * (1 - 1 / (1 + exp(x))) is also wrong, since that would require the reciprocal of the derivative.
What does the grad function equate to?
I finally figured it out. The dy should be called upstream_gradient or upstream_dy_dx.
By the chain rule, we know that

dy/dx[i+1] = dy/dx[i] * dx[i]/dx[i+1]

where dx[i]/dx[i+1] is the gradient of the current function.
So dy is the product of all the gradients upstream before this function.
So, if you forget to multiply by dy, the gradient contribution coming from upstream is discarded, much like tf.stop_gradient discards gradients, as the sketch below shows.
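For illustration, here is a sketch of my own (the name square_forgetting_dy and the outer factor of 3 are just for this example) where grad returns only the local gradient and drops dy:

import tensorflow as tf

@tf.custom_gradient
def square_forgetting_dy(x):
    def grad(dy):
        return 2 * x                       # bug: should be dy * 2 * x
    return x ** 2, grad

x = tf.constant(2.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = 3 * square_forgetting_dy(x)        # y = 3 * x ** 2, so the true dy/dx = 6 * x = 12

tf.print(tape.gradient(y, x))              # prints 4: the upstream factor of 3 never reaches x

The true derivative of 3 * x ** 2 at x = 2 is 12, but the tape reports 4, because the factor of 3 contributed by the outer multiplication is dropped when dy is ignored.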
Here is code which demonstrates the correct version, where each call multiplies its local gradient by dy and prints the upstream, current, and downstream values. A full notebook is here.
import tensorflow as tf

@tf.custom_gradient
def foo(x):
    tf.debugging.assert_rank(x, 0)  # this demo works with scalars only

    def grad(dy_dx_upstream):
        dy_dx = 2 * x                              # local gradient of x ** 2
        dy_dx_downstream = dy_dx * dy_dx_upstream  # chain rule: local * upstream
        tf.print(f'x={x}\tupstream={dy_dx_upstream}\tcurrent={dy_dx}\t\tdownstream={dy_dx_downstream}')
        return dy_dx_downstream

    y = x ** 2
    tf.print(f'x={x}\ty={y}')
    return y, grad

x = tf.constant(2.0, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = foo(foo(foo(x)))  # y = x ** 8

tf.print(f'\nfinal dy/dx={tape.gradient(y, x)}')
Output
x=2.0 y=4.0
x=4.0 y=16.0
x=16.0 y=256.0
x=16.0 upstream=1.0 current=32.0 downstream=32.0
x=4.0 upstream=32.0 current=8.0 downstream=256.0
x=2.0 upstream=256.0 current=4.0 downstream=1024.0
final dy/dx=1024.0
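As a sanity check: y = x ** 8, so dy/dx = 8 * x ** 7 = 8 * 2 ** 7 = 1024, which matches both the final tape.gradient result and the running product of the per-call local gradients printed above (1 * 32 * 8 * 4 = 1024).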