I am trying to integrate stable_baselines3 in dagshub and MlFlow. I am new to MLOPS
Here is a sample code that is easy to run:
import mlflow
import gym
from gym import spaces
import numpy as np
from stable_baselines3 import PPO
import os
os.environ['MLFLOW_TRACKING_USERNAME'] = "correct_dagshub_username"
os.environ['MLFLOW_TRACKING_PASSWORD'] = "correct_dagshub_token"
os.environ['MLFLOW_TRACKING_URI'] = "correct_URL")
# Create a simple custom gym environment
class SimpleEnv(gym.Env):
def __init__(self):
super(SimpleEnv, self).__init__()
self.action_space = spaces.Discrete(3)
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4,))
def step(self, action):
return np.array([0, 0, 0, 0]), 0, False, {}
def reset(self):
return np.array([0, 0, 0, 0])
# Create and train the model
env = SimpleEnv()
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=1000)
# Save the model using MLflow
mlflow.log_artifact("model.zip")
# Load the model from MLflow using the captured run_id
run_id = mlflow.active_run().info.run_id
loaded_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")
The problem is that I always get this error:
---------------------------------------------------------------------------
MlflowException Traceback (most recent call last)
Cell In[13], line 11
6 # Now the model is saved to MLflow with the corresponding run_id
7
8 # Step 5: Load the model from MLflow
9 run_id = mlflow.active_run().info.run_id
---> 11 loaded_model = mlflow.pytorch.load_model(f"runs:/{run_id}/model")
File ~\anaconda3\envs\metatrader\lib\site-packages\mlflow\pytorch\__init__.py:698, in load_model(model_uri, dst_path, **kwargs)
637 """
638 Load a PyTorch model from a local file or a run.
639
(...)
694 predict X: 30.0, y_pred: 60.48
695 """
696 import torch
--> 698 local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
699 pytorch_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
700 _add_code_from_conf_to_system_path(local_model_path, pytorch_conf)
File ~\anaconda3\envs\metatrader\lib\site-packages\mlflow\tracking\artifact_utils.py:100, in _download_artifact_from_uri(artifact_uri, output_path)
94 """
95 :param artifact_uri: The *absolute* URI of the artifact to download.
96 :param output_path: The local filesystem path to which to download the artifact. If unspecified,
97 a local output path will be created.
98 """
99 root_uri, artifact_path = _get_root_uri_and_artifact_path(artifact_uri)
--> 100 return get_artifact_repository(artifact_uri=root_uri).download_artifacts(
101 artifact_path=artifact_path, dst_path=output_path
102 )
File ~\anaconda3\envs\metatrader\lib\site-packages\mlflow\store\artifact\runs_artifact_repo.py:125, in RunsArtifactRepository.download_artifacts(self, artifact_path, dst_path)
110 def download_artifacts(self, artifact_path, dst_path=None):
111 """
112 Download an artifact file or directory to a local directory if applicable, and return a
113 local path for it.
(...)
123 :return: Absolute path of the local filesystem location containing the desired artifacts.
124 """
--> 125 return self.repo.download_artifacts(artifact_path, dst_path)
File ~\anaconda3\envs\metatrader\lib\site-packages\mlflow\store\artifact\artifact_repo.py:200, in ArtifactRepository.download_artifacts(self, artifact_path, dst_path)
197 failed_downloads[path] = repr(e)
199 if failed_downloads:
--> 200 raise MlflowException(
201 message=(
202 "The following failures occurred while downloading one or more"
203 f" artifacts from {self.artifact_uri}: {failed_downloads}"
204 )
205 )
207 return os.path.join(dst_path, artifact_path)
MlflowException: The following failures occurred while downloading one or more artifacts from URL/artifacts: {'model': 'MlflowException("API request to some api', port=443): Max retries exceeded with url: some_url (Caused by ResponseError(\'too many 500 error responses\'))")'}
Stable_baselines3 save the model as a zip file, I can see the artifact in MLflow but whatever I do cannot load the model from MLflow. I also tried it with
loaded_model = mlflow.pytorch.load_model(model_uri)
Any help would be greatly appreciated
When I ran your example, I got a different error:
Traceback (most recent call last):
File "/tmp/stable_baselines3/./train.py", line 36, in <module>
mlflow.pytorch.log_model(model, "model")
File "/tmp/stable_baselines3/.venv/lib/python3.11/site-packages/mlflow/pytorch/__init__.py", line 293, in log_model
return Model.log(
^^^^^^^^^^
File "/tmp/stable_baselines3/.venv/lib/python3.11/site-packages/mlflow/models/model.py", line 572, in log
flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs)
File "/tmp/stable_baselines3/.venv/lib/python3.11/site-packages/mlflow/pytorch/__init__.py", line 455, in save_model
raise TypeError("Argument 'pytorch_model' should be a torch.nn.Module")
TypeError: Argument 'pytorch_model' should be a torch.nn.Module
I am using gym==0.26.2
, mlflow==2.5.0
and stable-baselines3==2.0.0
on Python 3.11.3
. I think the error is a lot clearer in this case - PPO isn't a torch
model, and I couldn't find information on autologging stable_baselines3
models. So I set up a class through pyfunc
:
class PPOModelWrapper(mlflow.pyfunc.PythonModel):
def load_context(self, context):
self.model = PPO.load(context.artifacts["path"])
def predict(self, context, model_input):
action, states = self.model.predict(model_input)
return {"action": action, "states": states}
From there, you can log the model using mlflow.pyfunc.log_model
.
I've added the source code to the following repository: https://dagshub.com/jinensetpal/stable_baselines3, the logged model can be seen at: https://dagshub.com/jinensetpal/stable_baselines3.mlflow/#/experiments/0/runs/1f9e29528b5649b6a56a37ffb6a79a28/artifactPath/model
Hope this helps!