python  deep-learning  jax  flax

How to restore an Orbax checkpoint with JAX/Flax?


I saved an Orbax checkpoint with the code below:

import os
from pathlib import Path

import orbax.checkpoint as ocp

check_options = ocp.CheckpointManagerOptions(max_to_keep=5, create=True)
check_path = Path(os.getcwd(), out_dir, 'checkpoint')
checkpoint_manager = ocp.CheckpointManager(check_path, options=check_options, item_names=('state', 'metadata'))
checkpoint_manager.save(
    step=iter_num,
    args=ocp.args.Composite(
        state=ocp.args.StandardSave(state),
        metadata=ocp.args.JsonSave((model_args, iter_num, best_val_loss, losses['val'].item(), config))))

When I try to resume from the saved checkpoint, I use the code below to recover the state variable:

state, lr_schedule = init_train_state(model, params['params'], learning_rate, weight_decay, beta1, beta2, decay_lr, warmup_iters,
                     lr_decay_iters, min_lr)  # Here state is the initialized state variable of type Train_state.
state = checkpoint_manager.restore(checkpoint_manager.latest_step(), items={'state': state})

But when I use the recovered state in the training loop, I get this error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File /opt/conda/envs/py_3.10/lib/python3.10/site-packages/jax/_src/api_util.py:584, in shaped_abstractify(x)
    583 try:
--> 584   return _shaped_abstractify_handlers[type(x)](x)
    585 except KeyError:

KeyError: <class 'orbax.checkpoint.composite_checkpoint_handler.CompositeArgs'>

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[40], line 37
     34 if iter_num == 0 and eval_only:
     35     break
---> 37 state, loss = train_step(state, get_batch('train'))
     39 # timing and logging
     40 t1 = time.time()

    [... skipping hidden 6 frame]

File /opt/conda/envs/py_3.10/lib/python3.10/site-packages/jax/_src/api_util.py:575, in _shaped_abstractify_slow(x)
    573   dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
    574 else:
--> 575   raise TypeError(
    576       f"Cannot interpret value of type {type(x)} as an abstract array; it "
    577       "does not have a dtype attribute")
    578 return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
    579                         named_shape=named_shape)

TypeError: Cannot interpret value of type <class 'orbax.checkpoint.composite_checkpoint_handler.CompositeArgs'> as an abstract array; it does not have a dtype attribute

So, how should I correctly recover the state checkpoint and use it in the training loop?

Thanks!


Solution

  • You're mixing the old and new APIs in a way that is not allowed. Apologies that no error to that effect is raised; I can look into that.

    Your saving is correct, but I'd recommend that it look more like the following:

    with ocp.CheckpointManager(path, options=options, item_names=('state', 'metadata')) as mngr:
      mngr.save(
          step, 
          args=ocp.args.Composite(
              state=ocp.args.StandardSave(state),
              metadata=ocp.args.JsonSave(...),
          )
      )
    

    When restoring, you're currently using items, which is part of the old API. That usage is inconsistent with how the CheckpointManager was constructed, which follows the new API.

    item_names and args are hallmarks of the new API.

    You should do:

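    with ocp.CheckpointManager(...) as mngr:
      restored = mngr.restore(
          mngr.latest_step(), 
          args=ocp.args.Composite(
              # abstract_state: a pytree with the same structure, shapes and dtypes as state
              state=ocp.args.StandardRestore(abstract_state),
          )
      )
      # restore() returns a composite keyed by item name; pass only the
      # 'state' item to the training loop, not the whole composite.
      state = restored.state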
    

    Let me know if there are any unexpected issues with that.