
JAX NumPy: extracting non-NaN values gives NonConcreteBooleanIndexError


I have a 2D JAX array with some NaN values:

array_2d = jnp.array([
    [jnp.nan,        1,       2,   jnp.nan,    3],
    [10     ,jnp.nan,   jnp.nan,        20,jnp.nan]
    ])

and I want an array that contains, for each row, only the non-NaN values. The result therefore has the same number of rows but usually fewer columns, and rows with fewer non-NaN values are padded with NaN at the end. So in this case, the result should be

result = jnp.array([
    [1,   2,       3],
    [10, 20, jnp.nan]
    ])

The order (among non-nan values) should stay the same.

To make things easier, I know that each row has at most k (in this case 3) non-NaN values. Getting the indices of the non-NaN values is easy, but "moving them to the front" is harder.
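To pin down the expected output, here is a small eager NumPy reference. The name `compact_rows_reference` is hypothetical and k is passed explicitly. It relies on boolean masking, which is fine in plain NumPy because shapes may depend on the data, but that is exactly what cannot be traced under `vmap`; treat it only as a specification or test oracle.

```python
import numpy as np


def compact_rows_reference(a, k):
    """Hypothetical eager reference (not vmappable): move each row's non-NaN
    values to the front, keep their order, and pad with NaN up to width k.
    Assumes every row has at most k non-NaN values."""
    a = np.asarray(a, dtype=np.float32)
    out = np.full((a.shape[0], k), np.nan, dtype=a.dtype)
    for i, row in enumerate(a):
        vals = row[~np.isnan(row)]  # data-dependent shape: fine in NumPy only
        out[i, :vals.size] = vals
    return out


array_2d = np.array([
    [np.nan, 1, 2, np.nan, 3],
    [10, np.nan, np.nan, 20, np.nan],
])
result = compact_rows_reference(array_2d, k=3)
print(result)
```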

I tried working row by row; the following function does work on a single row:

# we want to vmap this over each row
def get_non_nan_values(row_vals):
    ret_arr = jnp.full(3, jnp.nan) # there are at most 3 non-nan values per row; pad with nan
    row_mask = ~jnp.isnan(row_vals)
    ret_vals = row_vals[row_mask] # this gets all (at most 3) non-nan values. However, its size is dynamic (data-dependent), so after vmapping this throws NonConcreteBooleanIndexError.
    ret_arr = ret_arr.at[:ret_vals.shape[0]].set(ret_vals) # this returns a FIXED-SIZE array
    return ret_arr

# the following works:
get_non_nan_values(array_2d[0,:]) # should return [1,2,3]

However, I can't vmap this. Even though I made sure that the returned array always has the same size, the line ret_vals = row_vals[row_mask] causes problems, because its size is dynamic. Does anyone know how to get around this? I believe that functions like `jnp.where` don't help either.
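A small check that isolates the problem: the boolean mask alone triggers the error under `vmap`, while the three-argument `jnp.where` keeps a static shape but only replaces values in place, so by itself it does not move the non-NaN values to the front. The helper name `mask_index` is hypothetical.

```python
import jax
import jax.numpy as jnp

array_2d = jnp.array([
    [jnp.nan, 1, 2, jnp.nan, 3],
    [10, jnp.nan, jnp.nan, 20, jnp.nan],
])


def mask_index(row):
    # Output length depends on the *values* in row, which are abstract under vmap.
    return row[~jnp.isnan(row)]


try:
    jax.vmap(mask_index)(array_2d)
    raised = False
except jax.errors.NonConcreteBooleanIndexError:
    raised = True

# Shape-preserving, so it vmaps fine, but the non-NaN values stay where they are.
zero_filled = jax.vmap(lambda row: jnp.where(jnp.isnan(row), 0.0, row))(array_2d)

print(raised)
print(zero_filled)
```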

Here is the full MWE:

import jax
import jax.numpy as jnp

array_2d = jnp.array([
    [jnp.nan,        1,       2,   jnp.nan,    3],
    [10     ,jnp.nan,   jnp.nan,        20,jnp.nan]
    ])

# we want to get -- efficiently -- all non-nan values per row.
# we know that each row has at most 3 non-nan values

# we will vmap this over each row
def get_non_nan_values(row_vals):
    ret_arr = jnp.full(3, jnp.nan) # there are at most 3 non-nan values per row; pad with nan
    row_mask = ~jnp.isnan(row_vals)
    ret_vals = row_vals[row_mask] # this gets all (at most 3) non-nan values. However, its size is dynamic (data-dependent), so after vmapping this throws NonConcreteBooleanIndexError.
    ret_arr = ret_arr.at[:ret_vals.shape[0]].set(ret_vals) # this returns a FIXED-SIZE array
    return ret_arr

# the following works:
get_non_nan_values(array_2d[0,:]) # should return [1,2,3]

# we now vmap
non_nan_vals = jax.vmap(get_non_nan_values)(array_2d) # this gives error: NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])

NB: In practice the array will be very large and contain many NaN values, while k (the maximum number of non-NaN values per row) is on the order of 10 or 100.
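Since k may be 10 to 100 here, it helps to make it a static argument instead of hard-coding 3. The sketch below is not from the original post and the name `compact_non_nan` is hypothetical. It uses a prefix-sum scatter (stream compaction): each non-NaN value's destination is the number of non-NaN values before it, computed with `jnp.cumsum`, and NaN entries are sent to an out-of-range slot that `.at[...].set(..., mode="drop")` discards. It assumes every row has at most k non-NaN values (any extras would be silently dropped), and its performance has not been measured.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="k")
def compact_non_nan(array_2d, k):
    """Hypothetical helper: move each row's non-NaN values to the front,
    preserving order, and pad with NaN to width k.
    Assumes at most k non-NaN values per row (extras are dropped)."""

    def one_row(row):
        mask = ~jnp.isnan(row)
        # Destination slot = number of non-NaN values strictly before this one.
        slots = jnp.cumsum(mask.astype(jnp.int32)) - 1
        # NaN entries go to slot k, which is out of range for a length-k output.
        slots = jnp.where(mask, slots, k)
        out = jnp.full((k,), jnp.nan, dtype=row.dtype)
        # mode="drop" discards updates whose index is out of bounds.
        return out.at[slots].set(row, mode="drop")

    return jax.vmap(one_row)(array_2d)


array_2d = jnp.array([
    [jnp.nan, 1, 2, jnp.nan, 3],
    [10, jnp.nan, jnp.nan, 20, jnp.nan],
])
result = compact_non_nan(array_2d, k=3)
print(result)
```

Because the output width depends only on k, changing k triggers a recompile, but the shape never depends on the data.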

Thank you very much!


Solution

  • If you first append one NaN to the end of each row, you can use jnp.nonzero with its size and fill_value arguments: size fixes the length of the returned index array, and fill_value is the index used for the remaining slots when a row has fewer than size non-NaN entries. With fill_value=-1, those padding indices point at the appended NaN. Here is a minimal example:

    import jax.numpy as jnp
    import jax
    
    array_2d = jnp.array([
        [jnp.nan,        1,       2,   jnp.nan,    3],
        [10     ,jnp.nan,   jnp.nan,        20,jnp.nan]
        ])
    
    
    @jax.vmap
    def get_non_nan_values(row_vals, size=3):
        padded = jnp.pad(row_vals, (0, 1), constant_values=jnp.nan)
        non_nan = jnp.nonzero(~jnp.isnan(padded), size=size, fill_value=-1)
        return padded[non_nan]
    
    get_non_nan_values(array_2d)
    

    Which returns:

    Array([[ 1.,  2.,  3.],
           [10., 20., nan]], dtype=float32)
    

    I think this solution is compact and clear in intent; however, I have not checked its performance.

    I hope this helps!
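To make the role of fill_value=-1 concrete, here are the intermediate indices for the second row of the example, computed with the same steps as the answer but without vmap. The single-argument form of jnp.where accepts the same size and fill_value keywords and returns the same indices.

```python
import jax.numpy as jnp

row = jnp.array([10.0, jnp.nan, jnp.nan, 20.0, jnp.nan])
padded = jnp.pad(row, (0, 1), constant_values=jnp.nan)  # length 6, last entry is NaN

# Only two entries are non-NaN, so the third slot gets fill_value=-1,
# which indexes the appended NaN at the end of `padded`.
(idx,) = jnp.nonzero(~jnp.isnan(padded), size=3, fill_value=-1)
values = padded[idx]

# Same indices via the single-argument form of jnp.where.
(idx_where,) = jnp.where(~jnp.isnan(padded), size=3, fill_value=-1)

print(idx)     # indices 0, 3, -1
print(values)  # 10, 20, nan
```

This is why the padding step matters: without the appended NaN, index -1 would select the last element of the original row, which can be a real value (for example 3 in the first row).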