
How to write a JAX custom vector-Jacobian product (vjp) for softmax


To understand JAX's reverse-mode autodiff, I tried to write a custom_vjp for softmax like this:

import jax
import jax.numpy as jnp
import numpy as np

@jax.custom_vjp
def stablesoftmax(x):
    print(f"input: {x} shape: {x.shape}")
    expc = jnp.exp(x - jnp.amax(x))
    return expc / jnp.sum(expc)

def ssm_fwd(x):
    s = stablesoftmax(x)
    return s, s

def ssm_bwd(acts, d_dacts):
    dacts_dinput = jnp.diag(acts) - jnp.outer(acts, acts)  # Jacobian
    d_dinput = jnp.dot(d_dacts, dacts_dinput)  # Vector-Jacobian product
    print(f"Saved activations:\n{acts} shape: {acts.shape}")
    print(f"d/d_acts:\n{d_dacts} shape: {d_dacts.shape}")
    print(f"d_acts/d_input (Jacobian of softmax):\n{dacts_dinput} shape: {dacts_dinput.shape}")
    print(f"d/d_input:\n{d_dinput} shape: {d_dinput.shape}")
    return d_dinput

stablesoftmax.defvjp(ssm_fwd, ssm_bwd)

print(f"JAX version: {jax.__version__}")
y = np.array([1., 2., 3.])
a = stablesoftmax(y)
softmax_jac_fun = jax.jacrev(stablesoftmax)
dsoftmax_dy = softmax_jac_fun(y)
print(f"Softmax Jacobian: {dsoftmax_dy}")
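For reference, the Jacobian built in ssm_bwd follows from differentiating the shifted softmax (the max shift m cancels). The last line is an O(n) form of the same vector-Jacobian product that avoids materializing the n x n matrix; it is an alternative the code above does not use.

```latex
\begin{aligned}
s_i &= \frac{e^{x_i - m}}{\sum_k e^{x_k - m}}, \qquad m = \max_k x_k \\
\frac{\partial s_i}{\partial x_j} &= s_i(\delta_{ij} - s_j)
  \quad\Longrightarrow\quad J = \operatorname{diag}(s) - s s^\top \\
(v^\top J)_j &= \sum_i v_i s_i(\delta_{ij} - s_j) = s_j\,(v_j - v^\top s)
\end{aligned}
```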

But when I call jacrev, I get an error saying that the structure of the VJP result does not match the structure of the input to the softmax:

JAX version: 0.2.13
input: [1. 2. 3.] shape: (3,)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
input: [1. 2. 3.] shape: (3,)
Saved activations:
[0.09003057 0.24472848 0.66524094] shape: (3,)
d/d_acts:
Traced<ShapedArray(float32[3])>with<BatchTrace(level=1/0)>
  with val = array([[1., 0., 0.],
                    [0., 1., 0.],
                    [0., 0., 1.]], dtype=float32)
       batch_dim = 0 shape: (3,)
d_acts/d_input (Jacobian of softmax):
[[ 0.08192507 -0.02203305 -0.05989202]
 [-0.02203305  0.18483645 -0.1628034 ]
 [-0.05989202 -0.1628034   0.22269544]] shape: (3, 3)
d/d_input:
Traced<ShapedArray(float32[3])>with<BatchTrace(level=1/0)>
  with val = DeviceArray([[ 0.08192507, -0.02203305, -0.05989202],
                          [-0.02203305,  0.18483645, -0.1628034 ],
                          [-0.05989202, -0.1628034 ,  0.22269544]], dtype=float32)
       batch_dim = 0 shape: (3,)
Traceback (most recent call last):
  File "analysis/vjp_test.py", line 30, in <module>
    dsoftmax_dy = softmax_jac_fun(y)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: Custom VJP rule must produce an output with the same container (pytree) structure as the args tuple of the primal function, and in particular must produce a tuple of length equal to the number of arguments to the primal function, but got VJP output structure PyTreeDef(*) for primal input structure PyTreeDef((*,)).

However, when I print the shapes, both have shape (3,), so why does JAX disagree? (The underlying values are actually 3x3 matrices, but that is because jacrev uses vmap to batch the VJPs, pulling back the entire standard basis of R^3, i.e. a 3x3 identity matrix, in one go.)
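A minimal sketch of the batching described above, assuming a plain softmax with no custom rule (the name `softmax` is just for illustration): each standard basis vector is pulled back through the function returned by `jax.vjp`, batched with `jax.vmap`. This mirrors the idea behind jacrev rather than its exact internals. Note that `vjp_fun` returns a tuple with one cotangent per primal argument.

```python
import jax
import jax.numpy as jnp

def softmax(x):
    # Plain softmax; JAX derives the VJP automatically
    e = jnp.exp(x - jnp.max(x))
    return e / jnp.sum(e)

x = jnp.array([1., 2., 3.])
y, vjp_fun = jax.vjp(softmax, x)

# Rows of the identity are the basis cotangents e_0, e_1, e_2
basis = jnp.eye(y.shape[0])
# vjp_fun returns a 1-tuple (one entry per primal argument), so unpack it;
# under vmap, row i of jac_rows is e_i^T J, i.e. the i-th row of the Jacobian
(jac_rows,) = jax.vmap(vjp_fun)(basis)

print(jac_rows.shape)  # (3, 3)
print(jnp.allclose(jac_rows, jax.jacrev(softmax)(x)))
```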

Note: I get the same error if I use jax.grad or jax.vjp directly.


Solution

  • According to the custom_vjp docs:

    The output of bwd must be a tuple of length equal to the number of arguments of the primal function

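    The shapes are fine; the mismatch is in the container. The error reports the primal input structure as PyTreeDef((*,)), a one-element tuple holding the single argument x, while ssm_bwd returned a bare array, PyTreeDef(*). So even for a single-argument function, the return statement in the backward pass should produce a 1-tuple: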

    def ssm_bwd(acts, d_dacts):
        ...
        return (d_dinput,)
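Putting it together, here is a sketch of the full corrected program (debug prints removed), checked against `jax.nn.softmax` as a reference. It keeps the question's explicit-Jacobian backward pass; only the return value changes. The `jax.grad` call covers the case from the note that the same error occurred with grad.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def stablesoftmax(x):
    expc = jnp.exp(x - jnp.max(x))
    return expc / jnp.sum(expc)

def ssm_fwd(x):
    s = stablesoftmax(x)
    # Save the softmax output as the residual for the backward pass
    return s, s

def ssm_bwd(acts, d_dacts):
    # Explicit Jacobian diag(s) - s s^T, as in the question
    dacts_dinput = jnp.diag(acts) - jnp.outer(acts, acts)
    d_dinput = jnp.dot(d_dacts, dacts_dinput)
    # One cotangent per primal argument: stablesoftmax takes one argument,
    # so return a 1-tuple rather than a bare array
    return (d_dinput,)

stablesoftmax.defvjp(ssm_fwd, ssm_bwd)

y = jnp.array([1., 2., 3.])
jac_custom = jax.jacrev(stablesoftmax)(y)
jac_ref = jax.jacrev(jax.nn.softmax)(y)
print(jnp.allclose(jac_custom, jac_ref, atol=1e-6))

# jax.grad needs a scalar output; the gradient of s_2 is row 2 of the Jacobian
grad_last = jax.grad(lambda x: stablesoftmax(x)[2])(y)
print(jnp.allclose(grad_last, jac_ref[2], atol=1e-6))
```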