What is the recommended way to do serialization/deserialization in JAX?
In the context of reinforcement learning, my starting point in terms of data might be, e.g., match replays that have to be pre-processed to obtain tuples of JAX arrays. This is a process I would like to do just once, save the results to disk, and then wrap a data loading interface such as grain.DataLoader around them.
But I have no idea how to do the actual serialization.
Right now I'm doing:
import jax
import jax.numpy as jnp
import numpy as np

def save_jax(path, x: jax.Array):
    y = np.array(x)
    np.save(path, y)

def load_jax(path):
    with open(path, "rb") as f:
        x = np.load(f)
    y = jnp.array(x)
    return y
It works. My hope is that there won't be any copies when converting jnp -> np or the other way around, and that the file will be memory-mapped.
Is this approach bad for performance? What is a better way?
Both JAX and NumPy support the Python buffer protocol, which allows zero-copy data sharing between the two array types. That being said, when using jnp.array or np.array the default is to copy (see the respective docs), so right now you make an unnecessary copy of the data. I would suggest using np.asarray instead (and its JAX equivalent, jnp.asarray), which only copies when needed. Your code would then look like:
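import jax
import jax.numpy as jnp
import numpy as np

def save_jax(path, x: jax.Array):
    y = np.asarray(x)  # copies only if needed (e.g. device-to-host transfer)
    np.save(path, y)

def load_jax(path):
    with open(path, "rb") as f:
        x = np.load(f)
    y = jnp.asarray(x)  # avoids the extra copy that jnp.array makes by default
    return y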
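On the memory-mapping part of the question: np.load only memory-maps a .npy file when it is called with mmap_mode, and it needs the file path rather than an open file handle for that; with the default mmap_mode=None, as in the code above, the whole array is read into memory. Below is a minimal sketch, assuming each array is stored in its own .npy file; the helper names save_array and load_array_mmap and the file name obs.npy are hypothetical and used only for illustration.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def save_array(path: str, x: jax.Array) -> None:
    # Hypothetical helper. np.asarray avoids an extra host-side copy where
    # possible; arrays on an accelerator still need a device-to-host transfer.
    np.save(path, np.asarray(x))


def load_array_mmap(path: str) -> np.ndarray:
    # Hypothetical helper. mmap_mode="r" maps the .npy file read-only instead
    # of reading it into memory; it must be given a path, not a file handle.
    return np.load(path, mmap_mode="r")


# Hypothetical file location for the example.
path = os.path.join(tempfile.mkdtemp(), "obs.npy")
save_array(path, jnp.arange(12, dtype=jnp.float32).reshape(3, 4))

host = load_array_mmap(path)  # an np.memmap backed by the file on disk

# Only the sliced rows are converted; jnp.asarray places that batch on the
# default device, which generally involves a transfer off the mapped file.
batch = jnp.asarray(host[:2])
```

Slicing the memmap before converting keeps the rest of the file on disk until it is actually touched, which tends to suit a per-batch data loader. Whether the final jnp.asarray call is zero-copy depends on the backend, so it is best not to rely on it.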
Aside from the copy, the native NumPy format might not be the best choice, either for I/O performance or for storing associated metadata. As a lightweight alternative, you might want to look into, for example, safetensors.
I hope this helps!