Suppose I have a vector of parameters `p` which parameterizes a set of matrices A_1(p), A_2(p), ..., A_N(p). I have a computation in which, for some list of indices `q` of length `M`, I have to compute A_{q_M} * ... * A_{q_2} * A_{q_1} * v for several different `q`s; for example, q = (2, 1, 3) means A_3(p) A_1(p) A_2(p) v, so the first index in the list acts on v first. Each `q` has a different length, but crucially doesn't change! What changes, and what I wish to take gradients with respect to, is `p`.
I'm trying to figure out how to convert this to performant JAX. One way to do it is to have some large matrix `Q` which contains all the different `q`s as rows, padded with the index of an identity matrix so that each multiplication chain has the same length, and then `scan` over a function that `switch`es between `N` different functions doing matrix-vector multiplications by A_n(p).
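For concreteness, here is a minimal sketch of that padded approach (the names `apply_chain`, `As`, and `Q` are illustrative; it assumes the matrices have been stacked into an array `As` of shape (N, d, d), with index 0 reserved for the identity used as padding):

```python
import jax
import jax.numpy as jnp

def apply_chain(As, q_row, v):
    # One branch per matrix: branch n computes A_n @ v.
    branches = [lambda v, A=A: A @ v for A in As]
    def step(v, idx):
        # Dispatch on the current index; padded entries (index 0) hit the identity.
        return jax.lax.switch(idx, branches, v), None
    v_final, _ = jax.lax.scan(step, v, q_row)
    return v_final

# Q holds one padded row of indices per chain:
# results = jax.vmap(lambda q_row: apply_chain(As, q_row, v))(Q)
```

Since every branch is just a matrix-vector product, indexing with `As[idx] @ v` inside the scan body would do the same job without `switch`.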
However, I don't particularly like the idea of this padding. Also, since `Q` here is fixed, is there potentially a smarter way to do this? The distribution of lengths of the `q`s has a very long tail, so `Q` will be dominated by padding.
EDIT: Here's a minimal example (edit 2: now functional):
```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)  # needed for complex128

sigma0 = jnp.eye(2)
sigmax = jnp.array([[0, 1], [1, 0]])
sigmay = jnp.array([[0, -1j], [1j, 0]])
sigmaz = jnp.array([[1, 0], [0, -1]])
sigma = jnp.array([sigmax, sigmay, sigmaz])

def gates_func(params):
    theta = params["theta"]
    epsilon = params["epsilon"]
    n = jnp.array([jnp.cos(theta), 0, jnp.sin(theta)])
    omega = jnp.pi / 2 * (1 + epsilon)
    X90 = expm(-1j * omega * jnp.einsum("i,ijk->jk", n, sigma) / 2)
    return {
        "Z90": expm(-1j * jnp.pi / 2 * sigmaz / 2),
        "X90": X90,
    }

def multiply_out(params):
    gate_lists = [["X90", "X90"], ["X90", "Z90"], ["Z90", "X90"], ["X90", "Z90", "X90"]]
    gates = gates_func(params)
    out = jnp.zeros(len(gate_lists))
    for i, gate_list in enumerate(gate_lists):
        init = jnp.array([1.0, 0.0], dtype=jnp.complex128)
        for g in gate_list:  # the first gate in the list is applied to the vector first
            init = gates[g] @ init
        out = out.at[i].set(jnp.abs(init[0]))
    return out

params = dict(theta=-0.0, epsilon=0.001)
multiply_out(params)
```
The main issue here is that JAX does not support string inputs. But you can use NumPy to manipulate string arrays and turn them into integer categorical arrays that can then be used by `jax.jit` and `jax.vmap`. The solution might look something like this:
```python
import numpy as np

def gates_func_int(params, gate_list_vals):
    g = gates_func(params)
    identity = jnp.eye(*list(g.values())[0].shape)
    # Map each unique string label to its matrix; the padding label ''
    # is not in the dict, so it falls back to the identity.
    return jnp.stack([g.get(val, identity) for val in gate_list_vals])

@jax.jit
def multiply_out_2(params):
    # Compile-time pre-processing: pad the gate lists with '' to a common
    # length, then encode the strings as integer categorical arrays.
    gate_lists = [["X90", "X90"], ["X90", "Z90"], ["Z90", "X90"], ["X90", "Z90", "X90"]]
    max_size = max(map(len, gate_lists))
    gate_array = np.array([gates + [''] * (max_size - len(gates))
                           for gates in gate_lists])
    gate_list_vals, gate_list_ints = np.unique(gate_array, return_inverse=True)
    gate_list_ints = gate_list_ints.reshape(gate_array.shape)
    # Runtime computation: look up the matrix for every (padded) gate index...
    gates = gates_func_int(params, gate_list_vals)[gate_list_ints]
    initial = jnp.array([[1.0], [0.0]], dtype=jnp.complex128)
    # ...and contract each chain. Reverse each row so the first gate in the
    # list is applied to the vector first, matching multiply_out above.
    return jax.vmap(
        lambda g: jnp.abs(jnp.linalg.multi_dot([*g[::-1], initial]))[0]
    )(gates).ravel()

multiply_out_2(params)
```
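Since the goal is to differentiate with respect to `p`, it's worth checking that this composes with `jax.grad`; a usage sketch (the scalar sum is just for illustration):

```python
# Gradient of a scalar summary of the outputs with respect to the parameters.
grads = jax.grad(lambda p: multiply_out_2(p).sum())(params)
# grads has the same pytree structure as params: {'epsilon': ..., 'theta': ...}
```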