jax

Serialization in JAX


What is the recommended way to do serialization/deserialization in JAX?

In the context of reinforcement learning, my starting data might be, e.g., match replays that have to be pre-processed into tuples of JAX arrays. I would like to do this just once, save the result to disk, and then wrap a data-loading interface such as grain.DataLoader around it.

But I have no idea how to do the actual serialization.

Right now I'm doing

import jax
import jax.numpy as jnp
import numpy as np

def save_jax(path, x: jax.Array):
    y = np.array(x)
    np.save(path, y)

def load_jax(path):
    with open(path, "rb") as f:
        x = np.load(f)
    y = jnp.array(x)
    return y

It works. My hope is that there are no copies when converting jnp->np or back, and that the file will be memory-mapped.

Is this approach bad for performance? What is a better way?


Solution

  • Both JAX and NumPy support the Python buffer protocol, which allows zero-copy data sharing between the two array types when the data already lives in host memory (an array on a GPU still has to be copied to the host). That said, np.array and jnp.array copy by default (copy=True; see the NumPy and JAX documentation for these functions), so your current code makes unnecessary copies. I would suggest using np.asarray and its JAX equivalent jnp.asarray instead, which only copy when needed. Your code would then look like:

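    import jax
    import jax.numpy as jnp
    import numpy as np

    def save_jax(path, x: jax.Array):
        y = np.asarray(x)  # copies only if needed, unlike np.array
        np.save(path, y)
    
    def load_jax(path):
        with open(path, "rb") as f:
            x = np.load(f)
        y = jnp.asarray(x)  # copies only if needed, unlike jnp.array
        return y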
    

    Aside from the copy, the native NumPy format might not be the best choice in terms of either I/O performance or associated metadata. For a lightweight alternative, you might want to look into safetensors, for example.

    I hope this helps!
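Regarding memory-mapping: np.load(f) as written reads the whole file into memory. NumPy only memory-maps a .npy file when you pass mmap_mode (e.g. mmap_mode="r") together with a file path rather than an open file handle. The following is a minimal, NumPy-only sketch under a few assumptions: each pre-processed sample is a tuple of arrays, every field is stored as its own .npy file (as far as I know, arrays inside an .npz archive are not memory-mapped by np.load), and the helper names save_arrays/load_arrays and the field_{i}.npy naming scheme are hypothetical. It also checks the np.array versus np.asarray copy behaviour described above.

```python
import os
import tempfile

import numpy as np

# Hypothetical pre-processed sample: a tuple of (observations, actions).
obs = np.arange(12, dtype=np.float32).reshape(4, 3)
actions = np.array([0, 1, 1, 0], dtype=np.int32)

# np.array copies by default; np.asarray returns the same object
# when no dtype or layout conversion is needed.
assert not np.shares_memory(np.array(obs), obs)
assert np.asarray(obs) is obs


def save_arrays(directory, arrays):
    """Hypothetical helper: save each array of a tuple as its own .npy file."""
    paths = []
    for i, arr in enumerate(arrays):
        path = os.path.join(directory, f"field_{i}.npy")
        np.save(path, np.asarray(arr))
        paths.append(path)
    return paths


def load_arrays(paths):
    """Hypothetical helper: open each .npy file as a read-only memmap.

    mmap_mode requires a path, not an open file handle; the data is
    paged in from disk lazily when it is accessed.
    """
    return tuple(np.load(p, mmap_mode="r") for p in paths)


out_dir = tempfile.mkdtemp()
paths = save_arrays(out_dir, (obs, actions))
loaded = load_arrays(paths)

# Each field is an np.memmap backed by its file, not an in-memory copy.
assert all(isinstance(a, np.memmap) for a in loaded)
np.testing.assert_array_equal(loaded[0], obs)
```

Keep in mind that passing one of these memmaps to jnp.asarray will still read the data and place it in a JAX device buffer (on a GPU this is a host-to-device transfer). Memory-mapping therefore mainly helps when a loader such as grain.DataLoader indexes into the files and reads only the records it needs for each batch.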