In the official tf.custom_gradient documentation, the example shows how to define custom gradients for log(1 + exp(x)):
```python
@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
        return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad
```
When y = log(1 + exp(x)), analytically the derivative comes out to be dy/dx = 1 - 1 / (1 + exp(x)).

However, in the code, grad returns dy * (1 - 1 / (1 + exp(x))).

dy/dx = dy * (1 - 1 / (1 + exp(x))) is not a valid equation, and dx = dy * (1 - 1 / (1 + exp(x))) is wrong, since that would have to be the reciprocal.

What does the grad function equate to?
I finally figured it out. The dy would be better named upstream_gradient or upstream_dy_dx.
By the chain rule, we know that dy/dx[i+1] = dy/dx[i] * dx[i]/dx[i+1], where dx[i]/dx[i+1] is the gradient of the current function and dy/dx[i] is the accumulated product so far.
So dy is the product of all the gradients upstream of this function in the backward pass.
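The chain-rule product can be checked by hand, with no TF at all, for the y = ((x²)²)² composition used in the demo below:

```python
# Chain rule by hand for y = ((x**2)**2)**2 at x = 2.
x1 = 2.0           # input
x2 = x1 ** 2       # 4.0;   local gradient dx2/dx1 = 2 * x1 = 4
x3 = x2 ** 2       # 16.0;  local gradient dx3/dx2 = 2 * x2 = 8
y = x3 ** 2        # 256.0; local gradient dy/dx3  = 2 * x3 = 32

# dy/dx1 is the product of the local gradients.
print(4.0 * 8.0 * 32.0)  # 1024.0
```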
So, if you forget to multiply by dy, the upstream gradients are discarded — effectively the same as applying tf.stop_gradient.
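Here is a minimal sketch of that failure mode (assumes TF 2.x eager execution; square_ok/square_bad are made-up names for illustration). The buggy grad returns only the local gradient, so any upstream scaling is silently lost:

```python
import tensorflow as tf

@tf.custom_gradient
def square_ok(x):
    def grad(dy):
        return dy * 2 * x  # correct: local gradient times upstream dy
    return x ** 2, grad

@tf.custom_gradient
def square_bad(x):
    def grad(dy):
        return 2 * x       # bug: upstream dy is discarded
    return x ** 2, grad

x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y_ok = square_ok(x) * 5.0   # dy/dx = 5 * 2x = 30
    y_bad = square_bad(x) * 5.0  # the upstream factor 5 is lost: gives 2x = 6

print(tape.gradient(y_ok, x).numpy())   # 30.0
print(tape.gradient(y_bad, x).numpy())  # 6.0
```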
Here is code which demonstrates this. A full notebook is here.
```python
@tf.custom_gradient
def foo(x):
    tf.debugging.assert_rank(x, 0)

    def grad(dy_dx_upstream):
        dy_dx = 2 * x
        dy_dx_downstream = dy_dx * dy_dx_upstream
        tf.print(f'x={x}\tupstream={dy_dx_upstream}\tcurrent={dy_dx}\t\tdownstream={dy_dx_downstream}')
        return dy_dx_downstream

    y = x ** 2
    tf.print(f'x={x}\ty={y}')
    return y, grad


x = tf.constant(2.0, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = foo(foo(foo(x)))  # y = x ** 8

tf.print(f'\nfinal dy/dx={tape.gradient(y, x)}')
```
Output:

```
x=2.0   y=4.0
x=4.0   y=16.0
x=16.0  y=256.0
x=16.0  upstream=1.0    current=32.0   downstream=32.0
x=4.0   upstream=32.0   current=8.0    downstream=256.0
x=2.0   upstream=256.0  current=4.0    downstream=1024.0

final dy/dx=1024.0
```
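As a sanity check on the output above: the composition is y = ((x²)²)² = x⁸, so analytically dy/dx = 8x⁷, which at x = 2 agrees with the tape:

```python
# y = x ** 8, so dy/dx = 8 * x ** 7; evaluate at x = 2.
x = 2.0
print(8 * x ** 7)  # 1024.0, matching tape.gradient(y, x)
```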