I saved an Orbax checkpoint with the code below:
check_options = ocp.CheckpointManagerOptions(max_to_keep=5, create=True)
check_path = Path(os.getcwd(), out_dir, 'checkpoint')
checkpoint_manager = ocp.CheckpointManager(check_path, options=check_options, item_names=('state', 'metadata'))
checkpoint_manager.save(
    step=iter_num,
    args=ocp.args.Composite(
        state=ocp.args.StandardSave(state),
        metadata=ocp.args.JsonSave((model_args, iter_num, best_val_loss, losses['val'].item(), config))))
When I tried to resume from the saved checkpoint, I used the code below to recover the state variable:
state, lr_schedule = init_train_state(model, params['params'], learning_rate, weight_decay, beta1, beta2, decay_lr, warmup_iters,
                                      lr_decay_iters, min_lr)  # Here state is the initialized state variable of type Train_state.
state = checkpoint_manager.restore(checkpoint_manager.latest_step(), items={'state': state})
But when I tried to use the recovered state in the training loop, I got this error:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
File /opt/conda/envs/py_3.10/lib/python3.10/site-packages/jax/_src/api_util.py:584, in shaped_abstractify(x)
583 try:
--> 584 return _shaped_abstractify_handlers[type(x)](x)
585 except KeyError:
KeyError: <class 'orbax.checkpoint.composite_checkpoint_handler.CompositeArgs'>
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
Cell In[40], line 37
34 if iter_num == 0 and eval_only:
35 break
---> 37 state, loss = train_step(state, get_batch('train'))
39 # timing and logging
40 t1 = time.time()
[... skipping hidden 6 frame]
File /opt/conda/envs/py_3.10/lib/python3.10/site-packages/jax/_src/api_util.py:575, in _shaped_abstractify_slow(x)
573 dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
574 else:
--> 575 raise TypeError(
576 f"Cannot interpret value of type {type(x)} as an abstract array; it "
577 "does not have a dtype attribute")
578 return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
579 named_shape=named_shape)
TypeError: Cannot interpret value of type <class 'orbax.checkpoint.composite_checkpoint_handler.CompositeArgs'> as an abstract array; it does not have a dtype attribute
So, how should I correctly recover the state checkpoint and use it in the training loop?
Thanks!
You're mixing the old and new APIs in a way that is not allowed. Apologies that an error to that effect is not being raised; I can look into that.
Your saving is correct, but I'd recommend that it look more like the following:
with ocp.CheckpointManager(path, options=options, item_names=('state', 'metadata')) as mngr:
    mngr.save(
        step,
        args=ocp.args.Composite(
            state=ocp.args.StandardSave(state),
            metadata=ocp.args.JsonSave(...),
        )
    )
When restoring, you're currently using the items argument, which is part of the old API. That usage is inconsistent with how the CheckpointManager was constructed, which follows the new API: item_names and args are hallmarks of the new API.
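To make the traceback above a little more concrete, here is a minimal sketch of why jit rejects the value. It assumes only what the traceback shows, namely that the object reaching train_step is a container (CompositeArgs) rather than a pytree of arrays. The Container class and the toy train_step below are hypothetical stand-ins, not Orbax or user code.

```python
import jax
import jax.numpy as jnp


class Container:
    """Hypothetical stand-in for a wrapper object that is not a registered pytree."""

    def __init__(self, **items):
        self.__dict__.update(items)


@jax.jit
def train_step(state):
    # Placeholder "training step": scales every array leaf of the state.
    return jax.tree_util.tree_map(lambda x: x * 0.5, state)


state = {"w": jnp.ones(3)}
wrapped = Container(state=state)

try:
    # The whole container is passed to the jitted function, as in the question.
    train_step(wrapped)
    failed = False
except TypeError as err:
    failed = True
    print(type(err).__name__)

# Passing the unwrapped pytree of arrays works.
new_state = train_step(wrapped.state)
```

The point of the sketch is only that a jitted function needs arrays (or pytrees of arrays) as arguments, so the container has to be unwrapped before it reaches the training loop.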
You should do:
with ocp.CheckpointManager(...) as mngr:
    mngr.restore(
        mngr.latest_step(),
        args=ocp.args.Composite(
            state=ocp.args.StandardRestore(abstract_state),
        )
    )
Let me know if there are any unexpected issues with that.