python, machine-learning, reinforcement-learning, openai-gym, stable-baselines

Getting a very simple stable-baselines3 example to work


I tried to model the simplest coin-flipping game, where you have to predict whether the flip is going to be heads. Sadly it won't run, giving me:

Using cpu device
Traceback (most recent call last):
  File "/home/user/python/simplegame.py", line 40, in <module>
    model.learn(total_timesteps=10000)
  File "/home/user/python/mypython3.10/lib/python3.10/site-packages/stable_baselines3/ppo/ppo.py", line 315, in learn
    return super().learn(
  File "/home/user/python/mypython3.10/lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 264, in learn
    total_timesteps, callback = self._setup_learn(
  File "/home/user/python/mypython3.10/lib/python3.10/site-packages/stable_baselines3/common/base_class.py", line 423, in _setup_learn
    self._last_obs = self.env.reset()  # type: ignore[assignment]
  File "/home/user/python/mypython3.10/lib/python3.10/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 77, in reset
    obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
TypeError: CoinFlipEnv.reset() got an unexpected keyword argument 'seed'

Here is the code:

import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

class CoinFlipEnv(gym.Env):
    def __init__(self, heads_probability=0.8):
        super(CoinFlipEnv, self).__init__()
        self.action_space = gym.spaces.Discrete(2)  # 0 for heads, 1 for tails
        self.observation_space = gym.spaces.Discrete(2)  # 0 for heads, 1 for tails
        self.heads_probability = heads_probability
        self.flip_result = None

    def reset(self):
        # Reset the environment
        self.flip_result = None
        return self._get_observation()

    def step(self, action):
        # Perform the action (0 for heads, 1 for tails)
        self.flip_result = int(np.random.rand() < self.heads_probability)

        # Compute the reward (1 for correct prediction, -1 for incorrect)
        reward = 1 if self.flip_result == action else -1

        # Return the observation, reward, done, and info
        return self._get_observation(), reward, True, {}

    def _get_observation(self):
        # Return the current coin flip result
        return self.flip_result

# Create the environment with heads probability of 0.8
env = DummyVecEnv([lambda: CoinFlipEnv(heads_probability=0.8)])

# Create the PPO model
model = PPO("MlpPolicy", env, verbose=1)

# Train the model
model.learn(total_timesteps=10000)

# Save the model
model.save("coin_flip_model")

# Evaluate the model
obs = env.reset()
for _ in range(10):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    print(f"Action: {action}, Observation: {obs}, Reward: {rewards}")

What am I doing wrong?

This is with stable-baselines3 version 2.2.1.


Solution

  • The gymnasium.Env class defines reset() with the following signature, which differs from yours, which takes no arguments. DummyVecEnv calls each sub-environment's reset() with a seed, which is what raises the TypeError.

    Env.reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) → tuple[ObsType, dict[str, Any]]

    In other words, seed and options are keyword-only arguments that your own reset() must accept, and it has to return an (observation, info) tuple.

    The fixed reset(), with the problems annotated:

        def reset(self, *, seed=None, options=None):  # Fix input signature
            # Reset the environment
            self.flip_result = 0  # None is not a valid observation
            return self.flip_result, {}  # Fix return signature
    

    If you return None as the observation, you get the next error instead: DummyVecEnv copies observations into a pre-allocated NumPy buffer (effectively array([0])[0] = obs), and assigning None into that array fails.
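
    As a side note (not required to fix the error), gymnasium recommends forwarding the seed to the base class so that self.np_random gets seeded. A minimal sketch of that variant:

        def reset(self, *, seed=None, options=None):
            # Forward the seed so gymnasium initialises self.np_random
            super().reset(seed=seed)
            self.flip_result = 0
            return self.flip_result, {}

    With that in place you could flip the coin via self.np_random.random() instead of np.random.rand() to make runs reproducible.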


    step() needs to return five values: observation, reward, terminated, truncated, info.

        def step(self, action):
            # Perform the action (0 for heads, 1 for tails)
            self.flip_result = int(np.random.rand() < self.heads_probability)
    
            # Compute the reward (1 for correct prediction, -1 for incorrect)
            reward = 1 if self.flip_result == action else -1
    
            # Return the observation, reward, done, truncated, and info
            return self._get_observation(), reward, True, False, {}
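
    You can also catch API mismatches like these before training with stable-baselines3's environment checker; a quick sketch, assuming the corrected CoinFlipEnv above:

        from stable_baselines3.common.env_checker import check_env

        # Warns (or raises) if the environment deviates from the Gymnasium API
        check_env(CoinFlipEnv(heads_probability=0.8), warn=True)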
    

    Now the model trains:

    -----------------------------
    | time/              |      |
    |    fps             | 5608 |
    |    iterations      | 1    |
    |    time_elapsed    | 0    |
    |    total_timesteps | 2048 |
    -----------------------------
    -----------------------------------------
    | time/                   |             |
    |    fps                  | 3530        |
    |    iterations           | 2           |
    |    time_elapsed         | 1           |
    |    total_timesteps      | 4096        |
    | train/                  |             |
    |    approx_kl            | 0.020679139 |
    |    clip_fraction        | 0.617       |
    |    clip_range           | 0.2         |
    |    entropy_loss         | -0.675      |
    |    explained_variance   | 0           |
    |    learning_rate        | 0.0003      |
    |    loss                 | 0.38        |
    |    n_updates            | 10          |
    |    policy_gradient_loss | -0.107      |
    |    value_loss           | 1           |
    -----------------------------------------
    -----------------------------------------
    | time/                   |             |
    |    fps                  | 3146        |
    |    iterations           | 3           |
    |    time_elapsed         | 1           |
    |    total_timesteps      | 6144        |
    | train/                  |             |
    |    approx_kl            | 0.032571375 |
    |    clip_fraction        | 0.628       |
    |    clip_range           | 0.2         |
    |    entropy_loss         | -0.599      |
    |    explained_variance   | 0           |
    |    learning_rate        | 0.0003      |
    |    loss                 | 0.392       |
    |    n_updates            | 20          |
    |    policy_gradient_loss | -0.104      |
    |    value_loss           | 0.987       |
    -----------------------------------------
    ---------------------------------------
    | time/                   |           |
    |    fps                  | 2984      |
    |    iterations           | 4         |
    |    time_elapsed         | 2         |
    |    total_timesteps      | 8192      |
    | train/                  |           |
    |    approx_kl            | 0.0691616 |
    |    clip_fraction        | 0.535     |
    |    clip_range           | 0.2       |
    |    entropy_loss         | -0.417    |
    |    explained_variance   | 0         |
    |    learning_rate        | 0.0003    |
    |    loss                 | 0.335     |
    |    n_updates            | 30        |
    |    policy_gradient_loss | -0.09     |
    |    value_loss           | 0.941     |
    ---------------------------------------
    ----------------------------------------
    | time/                   |            |
    |    fps                  | 2898       |
    |    iterations           | 5          |
    |    time_elapsed         | 3          |
    |    total_timesteps      | 10240      |
    | train/                  |            |
    |    approx_kl            | 0.12130852 |
    |    clip_fraction        | 0.125      |
    |    clip_range           | 0.2        |
    |    entropy_loss         | -0.189     |
    |    explained_variance   | 0          |
    |    learning_rate        | 0.0003     |
    |    loss                 | 0.536      |
    |    n_updates            | 40         |
    |    policy_gradient_loss | -0.0397    |
    |    value_loss           | 0.806      |
    ----------------------------------------
    Action: [1], Observation: [0], Reward: [1.]
    Action: [1], Observation: [0], Reward: [-1.]
    Action: [1], Observation: [0], Reward: [-1.]
    Action: [1], Observation: [0], Reward: [1.]
    Action: [1], Observation: [0], Reward: [1.]
    Action: [1], Observation: [0], Reward: [-1.]
    Action: [1], Observation: [0], Reward: [1.]
    Action: [1], Observation: [0], Reward: [-1.]
    Action: [1], Observation: [0], Reward: [1.]
    Action: [1], Observation: [0], Reward: [1.]
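
    For completeness, here is a sketch of the full script with both fixes applied (same structure, hyperparameters and file name as in the question; only reset() and step() changed):

        import gymnasium as gym
        import numpy as np
        from stable_baselines3 import PPO
        from stable_baselines3.common.vec_env import DummyVecEnv

        class CoinFlipEnv(gym.Env):
            def __init__(self, heads_probability=0.8):
                super().__init__()
                self.action_space = gym.spaces.Discrete(2)       # 0 for heads, 1 for tails
                self.observation_space = gym.spaces.Discrete(2)  # 0 for heads, 1 for tails
                self.heads_probability = heads_probability
                self.flip_result = 0

            def reset(self, *, seed=None, options=None):
                # Accept the keyword-only seed/options and return (observation, info)
                super().reset(seed=seed)
                self.flip_result = 0
                return self.flip_result, {}

            def step(self, action):
                # Perform the action (0 for heads, 1 for tails)
                self.flip_result = int(np.random.rand() < self.heads_probability)

                # Reward 1 for a correct prediction, -1 otherwise
                reward = 1 if self.flip_result == action else -1

                # observation, reward, terminated, truncated, info
                return self.flip_result, reward, True, False, {}

        # Create the environment with heads probability of 0.8
        env = DummyVecEnv([lambda: CoinFlipEnv(heads_probability=0.8)])

        # Train and save the PPO model
        model = PPO("MlpPolicy", env, verbose=1)
        model.learn(total_timesteps=10000)
        model.save("coin_flip_model")

        # Evaluate the model
        obs = env.reset()
        for _ in range(10):
            action, _states = model.predict(obs)
            obs, rewards, dones, info = env.step(action)
            print(f"Action: {action}, Observation: {obs}, Reward: {rewards}")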