The following code is used to save the model's train state during training and to restore that state back into memory.
from flax.training import orbax_utils
import orbax.checkpoint

directory_gen_path = "checkpoints_loc"
orbax_checkpointer_gen = orbax.checkpoint.PyTreeCheckpointer()
gen_options = orbax.checkpoint.CheckpointManagerOptions(save_interval_steps=5, create=True)
gen_checkpoint_manager = orbax.checkpoint.CheckpointManager(
    directory_gen_path, orbax_checkpointer_gen, gen_options
)

def save_model_checkpoints(step_, generator_state, generator_batch_stats):
    gen_ckpt = {
        "model": generator_state,
        "batch_stats": generator_batch_stats,
    }
    save_args_gen = orbax_utils.save_args_from_target(gen_ckpt)
    gen_checkpoint_manager.save(step_, gen_ckpt, save_kwargs={"save_args": save_args_gen})

def load_model_checkpoints(generator_state, generator_batch_stats):
    gen_target = {
        "model": generator_state,
        "batch_stats": generator_batch_stats,
    }
    latest_step = gen_checkpoint_manager.latest_step()
    gen_ckpt = gen_checkpoint_manager.restore(latest_step, items=gen_target)
    generator_state = gen_ckpt["model"]
    generator_batch_stats = gen_ckpt["batch_stats"]
    return generator_state, generator_batch_stats
The model was trained on a GPU, and loading the state onto a GPU device works fine. However, when I try to load the model on a CPU, the orbax checkpoint manager's restore method throws the following error:
ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices().
I'm not sure what the cause could be. Any thoughts, folks?
Update: After updating to the latest version of orbax-checkpoint, 0.8.0, the traceback changed to the following error:
ValueError: sharding passed to deserialization should be specified, concrete and an instance of `jax.sharding.Sharding`. Got None
What version of orbax.checkpoint are you using?
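If you are not sure which version is installed, a quick way to check is sketched below. It uses only the standard library; the helper name `installed_version` is made up for this example, and it returns None when the package is not installed in the current environment.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Print the versions that matter for this question.
for name in ("orbax-checkpoint", "jax", "flax"):
    print(f"{name}: {installed_version(name)}")
```

Run this in the same environment (and on the same machine) where the restore fails, since the GPU and CPU hosts may not have the same versions installed.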
It looks like this issue was fixed in https://github.com/google/orbax/issues/678, so you should update to the most recent version of orbax-checkpoint and try running your code again. If that doesn't work, I'd suggest reporting the problem at https://github.com/google/orbax/issues/new
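Since the second error complains that no sharding was specified, one thing that may be worth trying is to tell the restore call explicitly where each array should go. The following is only a sketch under assumptions, not a confirmed fix: the helper names `cpu_sharding_tree` and `cpu_restore_args` are made up for this example, and it assumes your orbax-checkpoint version provides `checkpoint_utils.construct_restore_args` and that the manager's `restore` accepts `restore_kwargs`, as the legacy `CheckpointManager` API does.

```python
import jax


def cpu_sharding_tree(target):
    """Map every leaf of `target` to a single-device sharding on the first CPU device."""
    cpu_device = jax.devices("cpu")[0]
    sharding = jax.sharding.SingleDeviceSharding(cpu_device)
    return jax.tree_util.tree_map(lambda _: sharding, target)


def cpu_restore_args(target):
    """Build orbax restore args that place restored arrays on the CPU (hypothetical helper)."""
    # Imported here so cpu_sharding_tree stays usable without orbax installed.
    import orbax.checkpoint as ocp

    return ocp.checkpoint_utils.construct_restore_args(
        target, cpu_sharding_tree(target)
    )


# Hypothetical usage inside load_model_checkpoints from the question:
#
#   restore_args = cpu_restore_args(gen_target)
#   gen_ckpt = gen_checkpoint_manager.restore(
#       latest_step,
#       items=gen_target,
#       restore_kwargs={"restore_args": restore_args},
#   )
```

The idea is that the checkpoint records the sharding it was saved with (here `cuda:0`), which does not exist on a CPU-only machine, so supplying a CPU sharding for every leaf gives the deserializer a concrete placement to use instead.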