I am very new to JAX, so please excuse me if this is something obvious or if I am making a silly mistake. I am trying to implement a function that does the following. All of these functions will be called from another JIT-ed function, so removing JIT may not be possible.
Here is the entire code:
import jax
import jax.numpy as jnp
from jax import random
from jax import lax
import copy
from copy import deepcopy
import numpy as np

def get_condition(state, x, y):
    L = (jnp.sqrt(len(jnp.asarray(state)))).astype(int)
    state = jnp.reshape(state, (L,L), order="F")
    s1 = state[x, y]
    branches = [lambda : (0,1), lambda : (1,0), lambda : (0,0)]
    conditions = jnp.array([s1==2, s1==4, True])
    result = lax.switch(jnp.argmax(conditions), branches)
    return tuple(x for x in result)

def update_state_vec(state, x, y, condition, list_scattered_states):
    L = (jnp.sqrt(len(state))).astype(int)

    def update_state_4(list_scattered_states):
        state1 = jnp.array( jnp.reshape(deepcopy(state), (L, L), order="F"))
        state1 = state1.at[x, y].set(4)
        list_scattered_states.append(jnp.ravel(state1, order="F"))
        return list_scattered_states

    def update_state_2(list_scattered_states):
        state1 = jnp.array( jnp.reshape(deepcopy(state), (L, L), order="F"))
        state1 = state1.at[x, y].set(2)
        list_scattered_states.append(jnp.ravel(state1, order="F"))
        return list_scattered_states

    def no_update_state (list_scattered_states):
        #state1 = jnp.ravel(state, order="F")
        #list_scattered_states.append(jnp.ravel(state, order="F"))
        #This doesn't work---------------------------------
        return list_scattered_states

    conditions = jnp.array([condition == (1, 0), condition == (0, 1), condition == (0, 0)])
    print(conditions)
    branches = [update_state_4, update_state_2,no_update_state]
    return(lax.switch(jnp.argmax(conditions), branches, operand=list_scattered_states))

def get_elements(state):
    L = (jnp.sqrt(len(state))).astype(int)
    list_scattered_states = []
    for x in range(L):
        for y in range(L):
            condition=get_condition(state, x, y)
            print(condition)
            list_scattered_states = update_state_vec(state, x, y, condition, list_scattered_states)
    return list_scattered_states
We can take an example input as follows,
arr=jnp.asarray([2., 1., 3., 4., 1., 2., 3., 4., 4., 1., 2., 3., 4., 2., 1., 3.])
get_elements(arr)
I get an error message as below:
print(conditions)
41 branches = [update_state_4, update_state_2,no_update_state]
---> 43 return(lax.switch(jnp.argmax(conditions), branches, operand=list_scattered_states))
TypeError: branch 0 and 2 outputs must have same type structure, got PyTreeDef([*]) and PyTreeDef([]).
So the error comes from the fact that no_update_state returns something that doesn't match the return type of update_state_4 or update_state_2. I am quite clueless at this point. Any help will be much appreciated.
The root of the issue here is that under transformations like jit, vmap, switch, etc. JAX requires the shape of outputs to be known statically, i.e. at compile time (see JAX sharp bits: dynamic shapes). In your case, the functions you are passing to switch return outputs of different shapes, and since jnp.argmax(conditions) is not known at compile time, there's no way for the compiler to know what memory to allocate for the result of this function.
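To make the constraint concrete, here is a minimal sketch (the function names same_structure and different_structure are made up for illustration, they are not from your code). The first switch has branches that return arrays of identical shape and dtype, so it traces fine even under jit. The second has one branch returning a one-element list and another returning an empty list, which is the same kind of mismatch as in your traceback:

```python
import jax
import jax.numpy as jnp
from jax import lax

def same_structure(index, x):
    # Both branches return an array with the same shape and dtype as x,
    # so the output structure is known without knowing the index.
    branches = [lambda v: v + 1.0, lambda v: v * 2.0]
    return lax.switch(index, branches, x)

def different_structure(index, x):
    # One branch returns a one-element list, the other an empty list,
    # so the two outputs have different pytree structures.
    branches = [lambda v: [v], lambda v: []]
    return lax.switch(index, branches, x)

x = jnp.arange(3.0)

# Works, including under jit.
ok = jax.jit(same_structure)(1, x)

# Fails while tracing the branches, before any branch is selected.
try:
    different_structure(0, x)
    mismatch_error = None
except TypeError as err:
    mismatch_error = err
```

The second call fails even though the index is a concrete 0, because switch traces every branch and compares their output structures up front.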
Since the snippet you posted is not itself JIT-compiled or otherwise transformed, the easiest way to address this there would be to replace the lax.switch statement with this:
if condition == (1, 0):
    list_scattered_states = update_state_4(list_scattered_states)
elif condition == (0, 1):
    list_scattered_states = update_state_2(list_scattered_states)
return list_scattered_states
If you do want your function to be compatible with jit or other JAX transformations, you'll have to re-write the logic so that the size of list_scattered_states remains constant, e.g. by padding it to the expected size from the beginning.
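As a rough sketch of what the padded version could look like (one possible layout, not the only one): always return one candidate row per grid position, plus a boolean mask saying which rows are real. The name get_elements_padded and the (candidates, valid) return convention are my own assumptions. The write rule follows the branches as I read them in your code, where update_state_4 runs when the entry is 4 and update_state_2 runs when the entry is 2; adjust new_values if you intended something else.

```python
import math
import jax
import jax.numpy as jnp

@jax.jit
def get_elements_padded(state):
    n = state.shape[0]   # shapes are static under jit, so this is a Python int
    L = math.isqrt(n)    # plain Python arithmetic, not a traced value

    # Visit (x, y) in the same order as the nested loops in the question.
    ks = jnp.arange(n)
    xs, ys = ks // L, ks % L
    flat = xs + ys * L   # column-major (order="F") flat index of (x, y)
    values = state[flat]

    # Positions that the question's code would append a state for.
    valid = (values == 2) | (values == 4)
    # Value written at each position, following the question's branches.
    new_values = jnp.where(values == 4, 4, 2).astype(state.dtype)

    # One candidate state per position; the output is always (n, n).
    candidates = jax.vmap(lambda i, v: state.at[i].set(v))(flat, new_values)
    # Padding rows (valid == False) are filled with the unchanged state.
    candidates = jnp.where(valid[:, None], candidates, state[None, :])
    return candidates, valid

arr = jnp.asarray([2., 1., 3., 4., 1., 2., 3., 4., 4., 1., 2., 3., 4., 2., 1., 3.])
candidates, valid = get_elements_padded(arr)
```

Downstream code can then use valid to ignore the padding rows, e.g. by masking inside a reduction, rather than building a Python list whose length depends on the data.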