I'm trying to follow this tutorial to tune hyperparameters in PyTorch using Ray. I copy-pasted everything, but I get the following error:
ImportError: cannot import name 'Checkpoint' from 'ray.air'
It comes from this import line:
from ray.air import Checkpoint
I installed Ray using pip install -U "ray[tune]", as suggested on the official website. After getting the error, to be sure, I also tried the more general pip install ray, which did not fix anything.
I have ray==2.9.0 installed.
Any help, please?
Try installing the older version 2.7.0:
pip install "ray[tune]==2.7.0"
In the newest version, the Ray AIR session has been replaced with a Ray Train context object.
You can import Checkpoint using:
from ray.train import Checkpoint
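If the same script has to run on both an older and a newer Ray install, one option is to try the new location first and fall back to the old one. This is only a sketch: first_available is a hypothetical helper name (not part of Ray), and the lookup order is my assumption.

```python
import importlib


def first_available(module_names, attr_name):
    """Return attr_name from the first importable module that defines it.

    Hypothetical helper for illustration; it is not part of Ray.
    """
    for module_name in module_names:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # Module (or its parent package) is not installed; try the next one.
            continue
        value = getattr(module, attr_name, None)
        if value is not None:
            return value
    raise ImportError(f"{attr_name} not found in any of: {', '.join(module_names)}")


# Intended use with Ray installed (new location first, legacy one second):
# Checkpoint = first_available(("ray.train", "ray.air"), "Checkpoint")
```

If you only target one Ray version, the plain import above is simpler and clearer.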
You need to adjust your code as follows:
from ray import air, train
# Ray Train methods and classes:
air.session.report -> train.report
air.session.get_dataset_shard -> train.get_dataset_shard
air.session.get_checkpoint -> train.get_checkpoint
air.Checkpoint -> train.Checkpoint
air.Result -> train.Result
# Ray Train configurations:
air.config.CheckpointConfig -> train.CheckpointConfig
air.config.FailureConfig -> train.FailureConfig
air.config.RunConfig -> train.RunConfig
air.config.ScalingConfig -> train.ScalingConfig
# Ray TrainContext methods:
air.session.get_experiment_name -> train.get_context().get_experiment_name
air.session.get_trial_name -> train.get_context().get_trial_name
air.session.get_trial_id -> train.get_context().get_trial_id
air.session.get_trial_resources -> train.get_context().get_trial_resources
air.session.get_trial_dir -> train.get_context().get_trial_dir
air.session.get_world_size -> train.get_context().get_world_size
air.session.get_world_rank -> train.get_context().get_world_rank
air.session.get_local_rank -> train.get_context().get_local_rank
air.session.get_local_world_size -> train.get_context().get_local_world_size
air.session.get_node_rank -> train.get_context().get_node_rank
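For a larger script, the renames listed above can be applied mechanically. Here is a rough sketch that rewrites a string of source code with a few of those mappings. RENAMES and migrate_source are names I made up for illustration, only a subset of the list is included, and import statements such as "from ray.air import Checkpoint" are not handled, so review the result by hand.

```python
import re

# Old -> new names, copied from the list above (subset; extend as needed).
RENAMES = {
    "air.session.report": "train.report",
    "air.session.get_checkpoint": "train.get_checkpoint",
    "air.session.get_trial_name": "train.get_context().get_trial_name",
    "air.Checkpoint": "train.Checkpoint",
    "air.config.RunConfig": "train.RunConfig",
}


def migrate_source(source: str) -> str:
    """Replace old Ray AIR names with their Ray Train equivalents."""
    for old, new in RENAMES.items():
        # Word boundaries keep e.g. "chair.Checkpoint" or
        # "air.CheckpointConfig" from being rewritten by accident.
        pattern = r"\b" + re.escape(old) + r"\b"
        # A function replacement avoids any escape handling in `new`.
        source = re.sub(pattern, lambda _match, new=new: new, source)
    return source
```

Calling migrate_source on the text of a training script returns the rewritten text; it does not touch files on disk.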
For more information, see: