
Dict Observation Space for Stable Baselines3 Not Working

I've created a minimal reproducible example below; it can be run in a new Google Colab notebook for ease. Once the first install finishes, just do Runtime > Restart and Run All for it to take effect.

I've made a simple roulette game environment below for testing. For the observation space, I've created a gym.spaces.Dict(), which you will see below (the code is well commented).

It trains just fine, but when it gets to the testing iteration, I get the error:

ValueError                                Traceback (most recent call last)
<ipython-input-56-7c2cb900b44f> in <module>
      6 obs = env.reset()
      7 for i in range(1000):
----> 8     action, _state = model.predict(obs, deterministic=True)
      9     obs, reward, done, info = env.step(action)
     10     env.render()

ValueError: Error: Unexpected observation shape () for Box environment, please use (1,) or (n_env, 1) for the observation shape.

I read somewhere that the Dict space needs to be flattened with gym.wrappers.FlattenObservation, so I changed this line:

    action, _state = model.predict(obs, deterministic=True)

...to:

    action, _state = model.predict(FlattenObservation(obs), deterministic=True)

...which results in this error:

AttributeError                            Traceback (most recent call last)
<ipython-input-57-87824c61fc45> in <module>
      6 obs = env.reset()
      7 for i in range(1000):
----> 8     action, _state = model.predict(FlattenObservation(obs), deterministic=True)
      9     obs, reward, done, info = env.step(action)
     10     env.render()

AttributeError: 'collections.OrderedDict' object has no attribute 'observation_space'

I've also tried doing this, which results in the same error as the last one:

obs = env.reset()
obs = FlattenObservation(obs)

So clearly I'm doing something wrong, but I just don't know what it is, as this is the first time I've worked with a Dict space.

import os, sys
if not os.path.isdir('/usr/local/lib/python3.7/dist-packages/stable_baselines3'):
    !pip3 install stable_baselines3
    print("\n\n\n Stable Baselines3 has been installed, Restart and Run All now. DO NOT factory reset, or you'll have to start over\n")
    sys.exit(0)

from random import randint
from numpy import inf, float32, array, int32, int64
import gym
from gym.wrappers import FlattenObservation
from stable_baselines3 import A2C, DQN, PPO

"""Roulette environment class"""
class Roulette_Environment(gym.Env):

    metadata = {'render.modes': ['human', 'text']}

    """Initialize the environment"""
    def __init__(self):
        super(Roulette_Environment, self).__init__()

        # Some global variables
        self.max_table_limit = 1000
        self.initial_bankroll = 2000

        # Spaces
        # Each number on roulette board can have 0-1000 units placed on it
        self.action_space = gym.spaces.Box(low=0, high=1000, shape=(37,))

        # We're going to keep track of how many times each number shows up
        # while we're playing, plus our current bankroll and the max
        # table betting limit so the agent knows how much $ in total is allowed
        # to be placed on the table. Going to use a Dict space for this.
        self.observation_space = gym.spaces.Dict(
            {
                # One hit counter per roulette number (0-36)
                **{str(i): gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int)
                   for i in range(37)},

                "current_bankroll": gym.spaces.Box(low=-inf, high=inf, shape=(1,), dtype=int),

                "max_table_limit": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
            }
        )

    """Reset the Environment"""
    def reset(self):
        self.current_bankroll = self.initial_bankroll
        self.done = False

        # Take a sample from the observation_space to modify the values of
        self.current_state = self.observation_space.sample()
        
        # Reset each number being tracked throughout gameplay to 0
        for i in range(0, 37):
            self.current_state[str(i)] = 0

        # Reset our globals
        self.current_state['current_bankroll'] = self.current_bankroll
        self.current_state['max_table_limit'] = self.max_table_limit
        
        return self.current_state


    """Step Through the Environment"""
    def step(self, action):
        
        # Convert actions to ints because they show up as floats,
        # even when defined as ints in the environment.
        # https://github.com/openai/gym/issues/3107
        for i in range(len(action)):
            action[i] = int(action[i])
        self.current_action = action
        
        # Total amount wagered this spin (netted against the payout below)
        sum_of_bets = sum(self.current_action)

        # Spin the wheel
        self.current_number = randint(a=0, b=36)

        # Calculate payout/reward
        self.reward = 36 * self.current_action[self.current_number] - sum_of_bets

        self.current_bankroll += self.reward

        # Update the current state
        self.current_state['current_bankroll'] = self.current_bankroll
        self.current_state[str(self.current_number)] += 1

        # If we've doubled our money, or lost our money
        if self.current_bankroll >= self.initial_bankroll * 2 or self.current_bankroll <= 0:
            self.done = True

        return self.current_state, self.reward, self.done, {}


    """Render the Environment"""
    def render(self, mode='text'):
        # Text rendering
        if mode == "text":
            print(f'Bets Placed: {self.current_action}')
            print(f'Number rolled: {self.current_number}')
            print(f'Reward: {self.reward}')
            print(f'New Bankroll: {self.current_bankroll}')

env = Roulette_Environment()

model = PPO('MultiInputPolicy', env, verbose=1)
model.learn(total_timesteps=10000)

obs = env.reset()
# obs = FlattenObservation(obs)

for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    # action, _state = model.predict(FlattenObservation(obs), deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
      obs = env.reset()
asked Dec 05 '25 by Matt Wilson

2 Answers

Unfortunately, stable-baselines3 is pretty picky about the observation format.
I ran into the same problem a few days ago.
Some documentation, as well as an example model, helped me figure things out:

- It is possible to use Dict observations.
- However, the values of Box spaces must be passed as numpy.ndarrays with the correct dtypes.
- For Discrete observations, the observation can also be passed as an int value; however, I'm not completely sure whether this still holds for multidimensional MultiDiscrete spaces.
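
As a minimal illustration (my own sketch with a made-up key name, not code from the question), a conforming Dict observation pairs each Box key with an ndarray of the declared shape and dtype:

import numpy as np
import gym

# A one-key Dict space, analogous to the question's per-number counters
space = gym.spaces.Dict({
    "count": gym.spaces.Box(low=0, high=np.inf, shape=(1,), dtype=int),
})

obs = {"count": np.array([0], dtype=int)}  # 1-element ndarray, matching dtype
print(space.contains(obs))                 # True: the format predict() accepts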

A very simple solution

A solution for your example would be to wrap the value as a 1-element numpy array every time you reassign a value in your Dict:

self.current_state[key] = np.array([value], dtype=int)

Below you'll find a working implementation of your problem (my system has Python 3.10 installed, but it should work on lower versions as well).

Working code:

import os, sys

from random import randint
from numpy import inf, float32, array, int32, int64
import gym
from gym.wrappers import FlattenObservation
from stable_baselines3 import A2C, DQN, PPO
import numpy as np

"""Roulette environment class"""
class Roulette_Environment(gym.Env):

    metadata = {'render.modes': ['human', 'text']}

    """Initialize the environment"""
    def __init__(self):
        super(Roulette_Environment, self).__init__()

        # Some global variables
        self.max_table_limit = 1000
        self.initial_bankroll = 2000

        # Spaces
        # Each number on roulette board can have 0-1000 units placed on it
        self.action_space = gym.spaces.Box(low=0, high=1000, shape=(37,))

        # We're going to keep track of how many times each number shows up
        # while we're playing, plus our current bankroll and the max
        # table betting limit so the agent knows how much $ in total is allowed
        # to be placed on the table. Going to use a Dict space for this.
        self.observation_space = gym.spaces.Dict(
            {
                # One hit counter per roulette number (0-36)
                **{str(i): gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int)
                   for i in range(37)},

                "current_bankroll": gym.spaces.Box(low=-inf, high=inf, shape=(1,), dtype=int),

                "max_table_limit": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
            }
        )

    """Reset the Environment"""
    def reset(self):
        self.current_bankroll = self.initial_bankroll
        self.done = False

        # Take a sample from the observation_space to modify the values of
        self.current_state = self.observation_space.sample()
        
        # Reset each number being tracked throughout gameplay to 0
        for i in range(0, 37):
            self.current_state[str(i)] = np.array([0], dtype=int)

        # Reset our globals
        self.current_state['current_bankroll'] = np.array([self.current_bankroll], dtype=int)
        self.current_state['max_table_limit'] = np.array([self.max_table_limit], dtype=int)
        
        return self.current_state


    """Step Through the Environment"""
    def step(self, action):
        
        # Convert actions to ints because they show up as floats,
        # even when defined as ints in the environment.
        # https://github.com/openai/gym/issues/3107
        for i in range(len(action)):
            action[i] = int(action[i])
        self.current_action = action
        
        # Total amount wagered this spin (netted against the payout below)
        sum_of_bets = sum(self.current_action)

        # Spin the wheel
        self.current_number = randint(a=0, b=36)

        # Calculate payout/reward
        self.reward = 36 * self.current_action[self.current_number] - sum_of_bets

        self.current_bankroll += self.reward

        # Update the current state
        self.current_state['current_bankroll'] = np.array([self.current_bankroll], dtype=int)
        self.current_state[str(self.current_number)] += np.array([1], dtype=int)

        # If we've doubled our money, or lost our money
        if self.current_bankroll >= self.initial_bankroll * 2 or self.current_bankroll <= 0:
            self.done = True

        return self.current_state, self.reward, self.done, {}


    """Render the Environment"""
    def render(self, mode='text'):
        # Text rendering
        if mode == "text":
            print(f'Bets Placed: {self.current_action}')
            print(f'Number rolled: {self.current_number}')
            print(f'Reward: {self.reward}')
            print(f'New Bankroll: {self.current_bankroll}')

env = Roulette_Environment()

model = PPO('MultiInputPolicy', env, verbose=1)
model.learn(total_timesteps=10)

obs = env.reset()
# obs = FlattenObservation(obs)

for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    # action, _state = model.predict(FlattenObservation(obs), deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
      obs = env.reset()
answered Dec 08 '25 by ttronas

You have 3 different issues here.

First, your main problem is in the reset method. You've defined your items as Boxes with shape=(1,). However, in reset you assign plain integers to those items, e.g., here: self.current_state[str(i)] = 0, and later for the current_bankroll and max_table_limit keys. SB3's BasePolicy.predict wraps your dict values with np.array(your_integer_value), which has shape (), and that of course raises an exception since it is incompatible with your Boxes' shape. Change your initial values to 1-element arrays, e.g., self.current_state[str(i)] = [0]. Also change your step method to update 1-element lists, not integers. That will solve your shape mismatch.
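
To see the shape mismatch described above, a two-line check (my illustration, not part of the answer) is enough:

import numpy as np

print(np.array(0).shape)    # () -- what SB3 builds from a bare int, triggering the error
print(np.array([0]).shape)  # (1,) -- what a Box with shape=(1,) expects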

Second, you can actually get rid of the Dict by manually flattening all your 1-element Boxes into a single Box. Your low would then become a list of per-element lower bounds (and if you change the current_bankroll value's low to 0, you don't even need to edit low at all; it can remain a single integer).
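
Note that gym.wrappers.FlattenObservation wraps an environment, not a single observation, which is why FlattenObservation(obs) raised the AttributeError in the question. A sketch of the manual flattening (my illustration, with an assumed index layout: 0-36 for the counters, 37 for the bankroll, 38 for the table limit; a single Box would also use 'MlpPolicy' instead of 'MultiInputPolicy'):

import numpy as np
import gym

# 37 counters (low 0), current_bankroll (low -inf), max_table_limit (low 0)
low = np.array([0.0] * 37 + [-np.inf, 0.0], dtype=np.float32)
high = np.full(39, np.inf, dtype=np.float32)
observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

def flatten_obs(counts, bankroll, table_limit):
    # counts: length-37 list of ints -> one (39,)-shaped observation
    return np.array(list(counts) + [bankroll, table_limit], dtype=np.float32)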

Third, other than the points above, your env looks correct. However, there's a bug in sb3. I assume you've installed sb3 with pip using the latest 1.6.2 tag (10th October). In this version there is a bug restricting the observation type in BaseAlgorithm.predict to np.ndarray only, which was later fixed in the master branch. So install sb3 directly from git.
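
In the Colab cell from the question that could look like the following (the URL is the official SB3 repository; pin a specific commit if you need reproducibility):

!pip install git+https://github.com/DLR-RM/stable-baselines3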

answered Dec 08 '25 by gehirndienst


