
ImportError: cannot import name 'Checkpoint' from 'ray.air'


I'm trying to follow this tutorial to tune hyperparameters in PyTorch using Ray. I copy-pasted everything, but I get the following error:

ImportError: cannot import name 'Checkpoint' from 'ray.air'

from this line of import:

from ray.air import Checkpoint

I installed Ray with pip install -U "ray[tune]", as suggested on the official website. After getting the error, I also tried a plain pip install ray to be sure, but that did not fix anything.
I have version ray==2.9.0 installed.

Any help, please?


Solution

  • Try installing the older version 2.7.0 (the quotes stop shells such as zsh from interpreting the square brackets):

    pip install "ray[tune]==2.7.0"
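To confirm which Ray version your Python environment actually picks up (for example after the downgrade, or when several environments are installed), a small check like the one below may help. It is a minimal sketch that uses only the standard library's importlib.metadata (Python 3.8+) and does not import Ray itself; the helper name installed_version is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str = "ray") -> Optional[str]:
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # After the downgrade above, this is expected to report 2.7.0.
    print("ray version:", installed_version("ray"))
```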
    

    Update:

    In newer Ray versions, including the 2.9.0 from the question, the Ray AIR APIs have moved into the ray.train namespace, and the Ray AIR session is replaced by a Ray Train context object. You can import Checkpoint with:

    from ray.train import Checkpoint
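If the same script has to run on both older Ray installations (where the tutorial's from ray.air import Checkpoint works) and newer ones, a version-tolerant import is one option. This is a minimal sketch; the helper name import_checkpoint is hypothetical. On a newer Ray the class it returns is ray.train.Checkpoint, which may not offer the same methods as the old ray.air one, so the other code changes listed in this answer can still be necessary.

```python
import importlib


def import_checkpoint():
    """Return Ray's Checkpoint class from whichever namespace this Ray version has.

    Tries the newer ray.train location first, then falls back to ray.air.
    """
    for module_name in ("ray.train", "ray.air"):
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # Ray, or this submodule, is not importable here
        checkpoint_cls = getattr(module, "Checkpoint", None)
        if checkpoint_cls is not None:
            return checkpoint_cls
    raise ImportError("Checkpoint found in neither ray.train nor ray.air")


# Usage (requires Ray to be installed):
# Checkpoint = import_checkpoint()
```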
    

    You will also need to adjust the rest of your code as follows:

    from ray import air, train
    
    # Ray Train methods and classes (old -> new):
    #   air.session.report               -> train.report
    #   air.session.get_dataset_shard    -> train.get_dataset_shard
    #   air.session.get_checkpoint       -> train.get_checkpoint
    #   air.Checkpoint                   -> train.Checkpoint
    #   air.Result                       -> train.Result
    
    # Ray Train configurations (old -> new):
    #   air.config.CheckpointConfig      -> train.CheckpointConfig
    #   air.config.FailureConfig         -> train.FailureConfig
    #   air.config.RunConfig             -> train.RunConfig
    #   air.config.ScalingConfig         -> train.ScalingConfig
    
    # Ray TrainContext methods (old -> new):
    #   air.session.get_experiment_name  -> train.get_context().get_experiment_name
    #   air.session.get_trial_name       -> train.get_context().get_trial_name
    #   air.session.get_trial_id         -> train.get_context().get_trial_id
    #   air.session.get_trial_resources  -> train.get_context().get_trial_resources
    #   air.session.get_trial_dir        -> train.get_context().get_trial_dir
    #   air.session.get_world_size       -> train.get_context().get_world_size
    #   air.session.get_world_rank       -> train.get_context().get_world_rank
    #   air.session.get_local_rank       -> train.get_context().get_local_rank
    #   air.session.get_local_world_size -> train.get_context().get_local_world_size
    #   air.session.get_node_rank        -> train.get_context().get_node_rank
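To illustrate the mapping, here is a minimal sketch of a Tune training function that saves and restores checkpoints through the new ray.train API (train.get_checkpoint, train.report and Checkpoint.from_directory), assuming a Ray version that has ray.train.Checkpoint, such as the 2.9.0 from the question. The function name train_loop, the state.pkl file, the "epochs" config key and the placeholder loss are all hypothetical; a real PyTorch script would save model and optimizer state (e.g. with torch.save) instead of a pickled dict. Ray is imported inside the function so the file can be loaded without Ray installed.

```python
import os
import pickle
import tempfile


def train_loop(config):
    """Hypothetical Tune trainable that checkpoints via the ray.train API."""
    # Imported inside the function so this module can be loaded without Ray.
    from ray import train
    from ray.train import Checkpoint

    start_epoch = 0
    # Old API: air.session.get_checkpoint()
    checkpoint = train.get_checkpoint()
    if checkpoint is not None:
        # Resume: read the state stored in the most recent checkpoint.
        with checkpoint.as_directory() as checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "state.pkl"), "rb") as f:
                start_epoch = pickle.load(f)["epoch"] + 1

    for epoch in range(start_epoch, config.get("epochs", 3)):
        loss = 1.0 / (epoch + 1)  # placeholder for a real PyTorch training step
        with tempfile.TemporaryDirectory() as tmp_dir:
            # A real PyTorch script would torch.save model/optimizer state here.
            with open(os.path.join(tmp_dir, "state.pkl"), "wb") as f:
                pickle.dump({"epoch": epoch}, f)
            # Old API: air.session.report(metrics, checkpoint=<an air.Checkpoint>)
            train.report(
                {"loss": loss},
                checkpoint=Checkpoint.from_directory(tmp_dir),
            )
```

With Ray Tune installed, the function would be passed to the tuner as usual, e.g. tune.Tuner(train_loop, param_space={"epochs": 3}).fit().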
    

    For more information, see: