 

ValueError: Error when checking input: expected flatten_input to have shape (1, 4) but got array with shape (1, 2)

I'm fairly new to RL and I can't really understand why I'm getting this error.

import random
import gym
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.optimizers import Adam
from rl.agents import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory



def build_model(states, actions):
    model = Sequential()
    model.add(Flatten(input_shape=(1,states)))
    model.add(Dense(24, activation='relu'))
    model.add(Dense(24, activation='relu'))
    model.add(Dense(actions, activation='linear'))
    return model

def build_agent(model, actions):
    policy = BoltzmannQPolicy()
    memory = SequentialMemory(limit=50000, window_length=1)
    dqn = DQNAgent(model=model, memory=memory, policy=policy, 
                  nb_actions=actions, nb_steps_warmup=10, target_model_update=1e-2)
    return dqn

def main():

    env = gym.make('CartPole-v1')
    states = env.observation_space.shape[0]
    actions = env.action_space.n
    #print(env.action_space.sample())
    print(env.reset())

    model = build_model(states, actions)


    dqn = build_agent(model, actions)
    dqn.compile(Adam(learning_rate=1e-3), metrics=['mae'])
    dqn.fit(env, nb_steps=50000, visualize=False, verbose=1)


if __name__ == "__main__":
    main()

I can't understand why it's getting an array with shape (1, 2). I've looked through similar questions from other people, but I can't apply their answers to my case. Training starts, but it fails immediately with 0 steps performed.

Thanks in advance!

Pedro Carvalho asked Dec 04 '25 14:12



1 Answer

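I got it working this way. The root cause is that Gym 0.26 changed its API: env.reset() now returns an (observation, info) tuple, and env.step() returns five values (observation, reward, terminated, truncated, info). keras-rl still expects the old API, so the first thing it feeds to the network is that 2-element tuple instead of the 4-element CartPole observation, hence the shape (1, 2). First of all, open keras-rl's rl/core.py file. Then: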

  1. Initialise first_time in the __init__ method of the Agent class:

self.first_time = True

  2. In my case I'm updating fit(), so go to fit() and replace

action = self.forward(observation)

with

if self.first_time:
    action = self.forward(observation[0])
    self.first_time = False
else:
    action = self.forward(observation)

  3. Inside the same while loop in fit(), under

if observation is None:

set first_time back to True:

if observation is None:  # start of a new episode
    self.first_time = True 
  4. Further down in the same loop, you will see

observation, r, done, info = env.step(action)

Replace it with:

observation, r, done, trunc, info = env.step(action)
done = done or trunc  # also end the episode when the time limit truncates it
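To see why these edits work, here is a small self-contained sketch. It assumes a hypothetical FakeNewApiEnv that mimics the Gym 0.26+ CartPole-v1 API (reset() returns (observation, info), step() returns five values) and replaces keras-rl's internals with a simplified, hypothetical batch_shape() helper; neither is part of gym or keras-rl. Keras leaves out the batch axis when it reports shapes, so (1, 1, 2) below corresponds to the "(1, 2)" in the error message.

```python
import numpy as np


class FakeNewApiEnv:
    """Hypothetical stand-in for CartPole-v1 under the Gym >= 0.26 API (not real gym)."""

    def __init__(self, episode_len=3):
        self.episode_len = episode_len
        self.t = 0

    def reset(self):
        self.t = 0
        # New API: reset() returns (observation, info)
        return np.zeros(4, dtype=np.float32), {}

    def step(self, action):
        self.t += 1
        terminated = self.t >= self.episode_len
        # New API: step() returns five values
        return np.ones(4, dtype=np.float32), 1.0, terminated, False, {}


def batch_shape(observation):
    # Mimic keras-rl wrapping one observation into a window (window_length=1)
    # and a batch of one; dtype=object mimics how older NumPy coerced ragged input
    return np.array([[observation]], dtype=object).shape


def run_patched_loop(env, n_steps):
    """Condensed sketch of the four edits to Agent.fit(); not keras-rl's actual code."""
    shapes = []
    observation = None
    first_time = True                # edit 1: flag initialised up front
    for _ in range(n_steps):
        if observation is None:      # start of a new episode
            first_time = True        # edit 3: re-arm the flag
            observation = env.reset()
        if first_time:               # edit 2: unwrap (observation, info) once
            current = observation[0]
            first_time = False
        else:
            current = observation
        shapes.append(batch_shape(current))
        # edit 4: unpack the five-value step() result
        observation, r, done, trunc, info = env.step(0)
        if done:
            observation = None       # episode over, reset on the next iteration
    return shapes


# Unpatched: the raw reset() tuple reaches the network
print(batch_shape(FakeNewApiEnv().reset()))  # (1, 1, 2)
# Patched: every observation carries CartPole's four features
print(run_patched_loop(FakeNewApiEnv(), 4))  # [(1, 1, 4), (1, 1, 4), (1, 1, 4), (1, 1, 4)]
```

Indexing with observation[0] is only needed on the first step of each episode, because every later observation comes from step(), which the fourth edit already unpacks correctly.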

And that did the job for me. Hope that helps :)
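An alternative that leaves rl/core.py untouched (a different technique from the patch described above) is to wrap the environment so keras-rl sees the old API. This is a minimal sketch, assuming keras-rl only calls reset(), step() and ordinary attributes such as action_space on the env; OldGymApi is a hypothetical name, not a class from gym or keras-rl.

```python
class OldGymApi:
    """Hypothetical adapter: presents a Gym >= 0.26 env through the pre-0.26 API.

    reset() -> observation                        (info is dropped)
    step()  -> (observation, reward, done, info)
    """

    def __init__(self, env):
        self.env = env

    def __getattr__(self, name):
        # Only called for attributes not found on the adapter itself; forward
        # them (action_space, observation_space, render, close, ...) to the env.
        # Guard against infinite recursion if self.env is not set yet.
        if name == "env":
            raise AttributeError(name)
        return getattr(self.env, name)

    def reset(self, **kwargs):
        observation, _info = self.env.reset(**kwargs)
        return observation

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        # The old API had a single done flag, so a time-limit truncation
        # also ends the episode here
        return observation, reward, terminated or truncated, info
```

You would then train on the wrapped environment, e.g. dqn.fit(OldGymApi(gym.make('CartPole-v1')), nb_steps=50000, visualize=False, verbose=1). Because it adapts the environment rather than fit(), it would also cover Agent.test(), which calls reset() and step() the same way.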

Lewlin Antony answered Dec 07 '25 08:12



