How to select between different functions based on the value of a parameter in Flax?


I am iterating through each attention head and applying either f1 or f2, depending on the value of the parameter self.alpha.

I only want to evaluate one of f1 or f2, not both, and use the output of whichever one the conditional selects.

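    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    # Inside the Flax module's method, where self.alpha and attn_weights are
    # available (attention heads are along axis 1 of both):

        # Two candidate normalizations for one head's attention weights.
        def f1(x):
            print('f1')  # Python print: runs at trace time
            return x / x.shape[2]

        def f2(x):
            print('f2')
            temp = nn.relu(x)
            return temp / (jnp.sum(temp, axis=-1, keepdims=True) + 1e-5)

        # Two-branch version using lax.cond (not used below). lax.cond needs a
        # boolean or integer predicate, so the float parameter is cast first.
        def choose_attention(alpha, x):
            return jax.lax.cond(alpha[0, 0, 0, 0].astype(int), f2, f1, x)

        results = []
        func = [f1, f2]  # branch index 0 -> f1, 1 -> f2
        for i in range(self.alpha.shape[1]):  # loop over heads
            print(i)
            alpha_i = self.alpha[:, i:i+1, :, :]
            x_i = attn_weights[:, i:i+1, :, :]
            # Pick the branch from this head's alpha (not always head 0's).
            result_i = jax.lax.switch(alpha_i[0, 0, 0, 0].astype(int), func, x_i)
            results.append(result_i)

        final_result = jnp.concatenate(results, axis=1)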

My print statements output: 0 f1 f2 1 2 3 4 5 6 7 8 9 10 11, so it looks as if both f1 and f2 are being evaluated.


Solution

  • jax.lax.switch does what you want: it chooses between different functions based on a runtime value. Your print statements are misleading you: Python's print runs at trace time rather than at runtime, and every branch passed to switch is traced even if it is never executed.

    For background on the execution model of JAX programs, I would suggest the How to think in JAX tutorial in the JAX documentation.

    Side note: for better performance, I would also suggest avoiding Python for loops over array values, and instead expressing your algorithm with either NumPy-style explicit vectorization or jax.vmap, which vectorizes your code automatically.
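To see which branch actually runs, here is a minimal sketch that contrasts a trace-time Python print with a runtime jax.debug.print inside lax.switch. It assumes a JAX version that provides jax.debug.print, and it uses jax.nn.relu in place of Flax's nn.relu (Flax re-exports the same function). The function name apply and the sample array x are hypothetical.

```python
import jax
import jax.numpy as jnp


def f1(x):
    print("tracing f1")            # Python print: runs once, while f1 is traced
    jax.debug.print("running f1")  # runtime print: fires only if this branch executes
    return x / x.shape[2]


def f2(x):
    print("tracing f2")
    jax.debug.print("running f2")
    temp = jax.nn.relu(x)
    return temp / (jnp.sum(temp, axis=-1, keepdims=True) + 1e-5)


@jax.jit
def apply(index, x):
    # Both branches are traced when apply is compiled; only one runs per call.
    return jax.lax.switch(index, [f1, f2], x)


# Hypothetical input with the same 4-D layout as one head slice: (batch, 1, q_len, k_len).
x = jnp.array([[[[1.0, -2.0, 3.0],
                 [0.5, 0.5, -1.0]]]])

out0 = apply(0, x).block_until_ready()  # traces both branches, then runs f1
out1 = apply(1, x).block_until_ready()  # reuses the compiled function, runs f2
```

Running this should show each "tracing" message once, and then a single "running" message per call for the branch that was actually selected.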
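Following the side note, a minimal sketch of the per-head loop rewritten with jax.vmap. It assumes self.alpha has shape (1, num_heads, 1, 1) and attn_weights has shape (batch, num_heads, q_len, k_len), which the slicing in the question suggests but does not confirm; per_head and switch_per_head are hypothetical names. One caveat: when the branch index is batched under vmap, JAX lowers switch to a select, so both f1 and f2 are computed for every head and the right result is picked afterwards.

```python
import jax
import jax.numpy as jnp


def f1(x):
    # x is one head's slice with the head axis removed: (batch, q_len, k_len),
    # so q_len is axis 1 here (axis 2 in the question's 4-D slices).
    return x / x.shape[1]


def f2(x):
    temp = jax.nn.relu(x)
    return temp / (jnp.sum(temp, axis=-1, keepdims=True) + 1e-5)


def per_head(alpha_h, x_h):
    # alpha_h: scalar for one head; truncate it to an integer branch index.
    return jax.lax.switch(alpha_h.astype(jnp.int32), [f1, f2], x_h)


def switch_per_head(alpha, attn_weights):
    # Assumed shapes: alpha (1, num_heads, 1, 1),
    # attn_weights (batch, num_heads, q_len, k_len).
    alpha_per_head = alpha[0, :, 0, 0]  # (num_heads,)
    # Map over heads: axis 0 of alpha_per_head, axis 1 of attn_weights,
    # and stack the per-head results back along axis 1.
    return jax.vmap(per_head, in_axes=(0, 1), out_axes=1)(alpha_per_head, attn_weights)


attn_weights = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 4, 5))
alpha = jnp.array([0.0, 1.0, 0.0]).reshape(1, 3, 1, 1)
final_result = switch_per_head(alpha, attn_weights)  # shape (2, 3, 4, 5)
```

Because the vmapped switch evaluates both branches anyway, this is effectively the same as the explicit NumPy-style approach of computing f1 and f2 for all heads and choosing per head with jnp.where; the benefit over the Python loop is a single vectorized computation instead of one traced switch per head.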