
Q-values exploding when training DQN

I'm training a DQN to play OpenAI Gym's Atari environments, but my network's Q-values quickly explode far above anything realistic.

Here's the relevant portion of the code:

for state, action, reward, next_state, done in minibatch:
    if not done:
        # To save on memory, next_state is just one frame,
        # so we have to add it to the current state to get the actual input for the network
        next_4_states = np.array(state)
        next_4_states = np.roll(next_4_states, 1, axis=3)
        next_4_states[:, :, :, 0] = next_state
        target = reward + self.gamma * \
            np.amax(self.target_model.predict(next_4_states))
    else:
        target = reward
    target_f = self.target_model.predict(state)
    target_f[0][action] = target

    self.target_model.fit(state, target_f, epochs=1, verbose=0)

The discount factor is 0.99 (the explosion doesn't happen with a discount factor of 0.9, but then the agent doesn't converge because it can't look far enough ahead).

Stepping through the code, the reason it's happening is that all the Q-values that aren't meant to be updated (the ones for actions we didn't take) increase slightly. My understanding is that passing the network's own output back to it as a training target should keep those outputs the same, not increase or decrease them. Is there something wrong with my model? Is there some way I can mask the update so it only changes the relevant Q-value?

EDIT: My model creation code is here:

# Assumes the usual Keras imports at module level, e.g.:
# from keras.models import Sequential
# from keras.layers import Convolution2D, Flatten, Dense
# from keras.optimizers import Adam
def create_model(self, input_shape, num_actions, learning_rate):
    model = Sequential()
    model.add(Convolution2D(32, 8, strides=(4, 4),
                            activation='relu', input_shape=input_shape))
    model.add(Convolution2D(64, 4, strides=(2, 2), activation='relu'))
    model.add(Convolution2D(64, 3, strides=(1, 1), activation='relu'))
    model.add(Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(Dense(num_actions))

    model.compile(loss='mse', optimizer=Adam(lr=learning_rate))

    return model

I create two of these: one for the online network and one for the target network.

asked Feb 21 '18 by Omegastick
1 Answer

Which predictions get updated?

Stepping through the code, the reason it's happening is that all the Q-values that aren't meant to be updated (the ones for actions we didn't take) increase slightly. My understanding is that passing the network's own output back to it as a training target should keep those outputs the same, not increase or decrease them.

Below I have drawn a very simple neural network with 3 input nodes, 3 hidden nodes, and 3 output nodes. Suppose that you have set a new target only for the first action, and simply reuse the existing predictions as targets for the other actions. This results in a non-zero error (denoted by delta in the image; for simplicity I'll just assume it's greater than zero) only for the first action/output, and errors of 0 for the others.

I have drawn the connections through which this error will be propagated from output to hidden layer in bold. Note how each of the nodes in the hidden layer still gets an error. When these nodes then propagate their errors back to the input layer, they'll do this through all of the connections between input and hidden layer, so all of those weights can be modified.

So, imagine all those weights got updated, and now imagine doing a new forward pass with the original inputs. Do you expect output nodes 2 and 3 to have exactly the same outputs as before? Probably not; the connections from the hidden nodes to the last two outputs may still have the same weights, but all three hidden nodes will have different activation levels. So no, the other outputs are not guaranteed to remain the same.

[Figure: example 3-3-3 neural network, with the connections that propagate the error for the first output drawn in bold]
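If you want to see this effect concretely, here is a minimal, self-contained sketch with a toy Keras network (not the asker's Atari model): we fit on a target in which only the first output was changed, and the remaining outputs typically drift anyway because the shared hidden-layer weights moved.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# Toy network: 3 inputs, one shared hidden layer, 3 outputs ("Q-values")
model = Sequential([
    Dense(16, activation='relu', input_shape=(3,)),
    Dense(3)
])
model.compile(loss='mse', optimizer='adam')

x = np.array([[0.1, 0.2, 0.3]])
before = model.predict(x)

target = before.copy()
target[0][0] += 1.0                     # new target for the first action only
model.fit(x, target, epochs=1, verbose=0)

after = model.predict(x)
print(before[0])
print(after[0])                         # the other two outputs usually shift as well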

Is there some way I can mask the update so it only updates the relevant Q value?

Not easily, no, if at all. The problem is that the connections between all pairs of layers other than the final pair are not action-specific, and I don't think you want them to be either.

Target Network

Is there something wrong with my model?

One thing I'm seeing is that you seem to be updating the same network that is used to generate targets:

target_f = self.target_model.predict(state)

and

self.target_model.fit(state, target_f, epochs=1, verbose=0)

both use self.target_model. Those two lines should use separate copies of the network: train one of them (the online network), and only after longer periods of time copy its updated weights into the network used to compute targets. For a bit more on this, see Addition 3 in this post.
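A minimal sketch of how that could look, building on the loop from the question. The names self.model (the online network that gets trained), self.step_count and self.target_update_freq are illustrative assumptions, not taken from the asker's code:

for state, action, reward, next_state, done in minibatch:
    if not done:
        next_4_states = np.array(state)
        next_4_states = np.roll(next_4_states, 1, axis=3)
        next_4_states[:, :, :, 0] = next_state
        # Bootstrap targets come from the (frozen) target network ...
        target = reward + self.gamma * \
            np.amax(self.target_model.predict(next_4_states))
    else:
        target = reward
    # ... but the predictions we overwrite and the training step use the online network
    target_f = self.model.predict(state)
    target_f[0][action] = target
    self.model.fit(state, target_f, epochs=1, verbose=0)

self.step_count += 1
# Only every so often (e.g. every few thousand steps) sync the target network
if self.step_count % self.target_update_freq == 0:
    self.target_model.set_weights(self.model.get_weights())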

Double DQN

Apart from that, it is well known that DQN can still have a tendency to overestimate Q values (though it generally shouldn't completely explode). This can be addressed by using Double DQN (note: this is an improvement that was added later on top of DQN).
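Sketching only the target computation, again with the hypothetical online network self.model alongside self.target_model (this illustrates the Double DQN idea; it is not code from the question): the online network selects the best next action, and the target network evaluates it.

if not done:
    # Double DQN: the online network picks the greedy next action ...
    best_action = np.argmax(self.model.predict(next_4_states)[0])
    # ... and the target network supplies that action's value estimate
    target = reward + self.gamma * \
        self.target_model.predict(next_4_states)[0][best_action]
else:
    target = reward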

answered Oct 25 '22 by Dennis Soemers