I would like to use the package Oryx to invert an affine transformation written in JAX. The transformation maps x->y and depends on a set of adjustable parameters (which I call params). Specifically, the affine transformation is defined as:
import jax.numpy as jnp
def affine(params, x):
    return x * params['scale'] + params['shift']
params = dict(scale=1.5, shift=-1.)
x_in = jnp.array(3.)
y_out = affine(params, x_in)
I would like to invert affine with respect to the input x, as a function of params. Oryx has a function oryx.core.inverse to invert JAX functions. However, inverting a function with parameters, like this:
import oryx
oryx.core.inverse(affine)(params, y_out)
doesn't work (AssertionError: length mismatch: [1, 3]), presumably because inverse doesn't know that I want to invert y_out but not params.
What is the most elegant way to solve this problem for all possible values (i.e., as a function) of params using oryx.core.inverse?
I find the inverse docs not very illuminating.
Update:
Jakevdp gave an excellent suggestion for a given set of params. I've clarified the question to indicate that I am wondering how to define the inverse as a function of params.
You can do this by closing over the static parameters, for example using partial:
from functools import partial
x = oryx.core.inverse(partial(affine, params))(y_out)
print(x)
# 3.0
Edit: if you want a single inverted function to work for multiple values of params, you will have to return params in the output (otherwise, there's no way from a single output value to infer all three inputs). It might look something like this:
def affine(params, x):
    return params, x * params['scale'] + params['shift']
params = dict(scale=1.5, shift=-1.)
x_in = jnp.array(3.)
_, y_out = affine(params, x_in)
_, x = oryx.core.inverse(affine)(params, y_out)
print(x)
# 3.0
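As a sanity check on either approach, the affine map can also be inverted by hand, since x = (y - shift) / scale whenever scale is nonzero. The sketch below does not use Oryx at all. The helper name affine_inverse_reference and the batched parameter values are made up for illustration, and the result is only meant as a reference to compare against what oryx.core.inverse returns for different params.

```python
import jax
import jax.numpy as jnp

def affine(params, x):
    return x * params['scale'] + params['shift']

def affine_inverse_reference(params, y):
    # Hypothetical helper: closed-form inverse of the affine map.
    # Assumes params['scale'] is nonzero.
    return (y - params['shift']) / params['scale']

params = dict(scale=1.5, shift=-1.)
x_in = jnp.array(3.)
y_out = affine(params, x_in)
x_ref = affine_inverse_reference(params, y_out)

# The reference inverse is an ordinary JAX function of params, so it can be
# mapped over a batch of (illustrative) parameter values with jax.vmap.
batched_params = dict(scale=jnp.array([1.5, 2.0, 0.5]),
                      shift=jnp.array([-1., 0., 4.]))
y_batch = jax.vmap(affine, in_axes=(0, None))(batched_params, x_in)
x_batch = jax.vmap(affine_inverse_reference)(batched_params, y_batch)
```

If the Oryx-based inverse is set up correctly, its output for a given params and y_out should agree with x_ref up to floating-point tolerance, and every entry of x_batch should recover x_in.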