
Restoring flax model checkpoints using orbax throws ValueError


The following code blocks are used to save the model's train state during training and to restore that state into memory.


from flax.training import orbax_utils
import orbax.checkpoint

# Checkpoints go to this directory; save_interval_steps=5 limits saving to every fifth step.
directory_gen_path = "checkpoints_loc"
orbax_checkpointer_gen = orbax.checkpoint.PyTreeCheckpointer()
gen_options = orbax.checkpoint.CheckpointManagerOptions(save_interval_steps=5, create=True)
gen_checkpoint_manager = orbax.checkpoint.CheckpointManager(
    directory_gen_path, orbax_checkpointer_gen, gen_options
)

def save_model_checkpoints(step_, generator_state, generator_batch_stats):
    # Bundle the train state and the batch-norm statistics into one pytree.
    gen_ckpt = {
        "model": generator_state,
        "batch_stats": generator_batch_stats,
    }

    # Derive per-leaf save arguments from the pytree, then save it at this step.
    save_args_gen = orbax_utils.save_args_from_target(gen_ckpt)
    gen_checkpoint_manager.save(step_, gen_ckpt, save_kwargs={"save_args": save_args_gen})

def load_model_checkpoints(generator_state, generator_batch_stats):
    # The target pytree tells Orbax which structure to restore into.
    gen_target = {
        "model": generator_state,
        "batch_stats": generator_batch_stats,
    }

    # Restore the most recent checkpoint into the target structure.
    latest_step = gen_checkpoint_manager.latest_step()
    gen_ckpt = gen_checkpoint_manager.restore(latest_step, items=gen_target)
    generator_state = gen_ckpt["model"]
    generator_batch_stats = gen_ckpt["batch_stats"]

    return generator_state, generator_batch_stats


The model was trained on a GPU, and restoring the state onto a GPU device works fine. However, when I try to load the model onto the CPU, the Orbax checkpoint manager's restore method throws the following error:

ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices().

I'm not quite sure what the reason could be. Any thoughts, folks?
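The first error suggests that a sharding referring to `cuda:0` (presumably recorded when the GPU-trained state was saved) is being used at restore time, and that device does not exist among the CPU machine's `jax.local_devices()`. A minimal JAX sketch can confirm what the restoring process actually sees. `local_platforms` and `sharding_device_is_local` are hypothetical helpers written for illustration; they are not part of JAX or Orbax.

```python
import jax


def local_platforms() -> list:
    """Return the sorted, de-duplicated platform names of this process's devices.

    jax.local_devices() only lists devices of the default backend, so on a
    CPU-only host it contains only CPU devices and no "cuda:0" device.
    """
    return sorted({device.platform for device in jax.local_devices()})


def sharding_device_is_local(sharding: jax.sharding.Sharding) -> bool:
    """Check whether every device referenced by `sharding` exists in this process."""
    local = set(jax.local_devices())
    return sharding.device_set.issubset(local)


if __name__ == "__main__":
    print("default backend:", jax.default_backend())
    print("local platforms:", local_platforms())
```

Running this on the CPU machine should show only CPU devices, which is consistent with a `cuda:0` sharding failing the lookup.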

Update: After upgrading to the latest version of orbax-checkpoint (0.8.0), the traceback changed to the following error:

ValueError: sharding passed to deserialization should be specified, concrete and an instance of `jax.sharding.Sharding`. Got None
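The second error says deserialization received `None` where a concrete `jax.sharding.Sharding` was expected. One commonly suggested workaround (an assumption here, not confirmed in this thread) is to hand Orbax a restore target whose arrays already live on the CPU, derive restore arguments from it, for example with `flax.training.orbax_utils.restore_args_from_target(target, mesh=None)`, and pass them to the legacy `CheckpointManager.restore` through `restore_kwargs={"restore_args": ...}`. The sketch below covers only the JAX half: it moves a stand-in target onto the CPU and checks that every leaf carries a concrete `SingleDeviceSharding`. The helper names and the stand-in target are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding


def place_on_cpu(tree):
    """Copy every array leaf of `tree` onto the first CPU device."""
    cpu = jax.devices("cpu")[0]
    return jax.device_put(tree, cpu)


def leaf_shardings(tree):
    """Return a pytree holding the concrete sharding of each array leaf."""
    return jax.tree_util.tree_map(lambda x: x.sharding, tree)


if __name__ == "__main__":
    # Hypothetical stand-in for the real {"model": ..., "batch_stats": ...} target.
    target = {
        "model": {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros(2)},
        "batch_stats": {"mean": jnp.zeros(2), "var": jnp.ones(2)},
    }
    cpu_target = place_on_cpu(target)
    # Assumption: a target like this would then be passed to
    # orbax_utils.restore_args_from_target(cpu_target, mesh=None).
    for sharding in jax.tree_util.tree_leaves(leaf_shardings(cpu_target)):
        print(type(sharding).__name__, sharding.device_set)
```

Whether this alone resolves the error depends on the installed Orbax version, so it is worth testing against the actual checkpoint.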

Solution

  • What version of orbax.checkpoint are you using?

    It looks like this issue was fixed in https://github.com/google/orbax/issues/678, so you should update to the most recent version of orbax-checkpoint and try running your code again. If that doesn't work, I'd suggest reporting the problem at https://github.com/google/orbax/issues/new
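To answer the version question, and to include exact versions in a bug report, a small standard-library sketch can print what is installed. The choice of packages to report is ours; only `importlib.metadata` from the standard library is used.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str = "orbax-checkpoint") -> Optional[str]:
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    for name in ("orbax-checkpoint", "flax", "jax", "jaxlib"):
        print(f"{name}: {installed_version(name) or 'not installed'}")
```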