 

keras stateful LSTM

Tags: python, keras, lstm

Please consider this simple example:

import numpy as np

nb_samples = 100000
X = np.random.randn(nb_samples)
Y = X[:-1]  # the target at step i is the sample that precedes the input
X = X[1:]   # the input at step i is the sample that follows the target
X = X.reshape((len(Y), 1, 1))  # (samples, timesteps=1, features=1)
Y = Y.reshape((len(Y), 1))

So we basically have

Y[i] = X[i-1]

i.e. the target at step i is the previous input, and the model to learn is simply a one-step lag operator.
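
A quick sanity check of that relation (my own throwaway snippet, assuming the arrays built above):

assert np.allclose(Y[1:, 0], X[:-1, 0, 0])  # Y[i] == X[i-1] for every i >= 1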

I can learn this model with a stateless LSTM, but here I want to understand and apply stateful LSTMs in Keras.

So I try to learn this model with a stateful LSTM, feeding it the (x, y) pairs one by one (batch_size = 1):

from keras.models import Sequential
from keras.layers import Dense, LSTM

model = Sequential()
model.add(LSTM(output_dim=10,
               batch_input_shape=(1, 1, 1),  # stateful layers need a fixed batch size
               activation='tanh',
               stateful=True))
model.add(Dense(output_dim=1, activation='linear'))
model.compile(loss='mse', optimizer='adam')
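
If I understand correctly, stateful=True means that the state computed for batch position i is kept and reused as the initial state for position i of the next batch, until reset_states() is called; that is also why batch_input_shape has to fix the batch size. A quick check of this behaviour (my own illustration, assuming the model above):

model.reset_states()
p1 = model.predict(X[0:1], batch_size=1)  # leaves a non-zero state behind
p2 = model.predict(X[1:2], batch_size=1)  # starts from the state left by p1

model.reset_states()
p3 = model.predict(X[1:2], batch_size=1)  # same input, but a fresh state
print(p2, p3)  # generally differ, showing the carried-over state matters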


X_train, Y_train = X, Y  # train on the whole toy dataset

for epoch in range(50):
    model.fit(X_train,
              Y_train,
              nb_epoch=1,
              verbose=2,
              batch_size=1,
              shuffle=False)
    model.reset_states()
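
To see what the network has learned, I predict step by step with the same batch size (a rough check I use for illustration; predict has to use batch_size = 1 to match batch_input_shape):

model.reset_states()
preds = model.predict(X_train[:100], batch_size=1)
print(np.mean((preds - Y_train[:100]) ** 2))  # stays around 1 here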

But the model does not learn anything.

As per Marcin's suggestion, I modified the training code as follows:

for epoch in range(10000):
    model.reset_states()
    train_loss = 0.
    for i in range(Y_train.shape[0]):
        train_loss += model.train_on_batch(X_train[i:i+1],
                                           Y_train[i:i+1])
    print('# epoch %d  loss %f' % (epoch, train_loss / Y_train.shape[0]))

but I am still seeing a mean loss around 1, which is just the variance of my randomly generated data, so the model does not seem to learn anything.
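
For comparison, the trivial predictor that always outputs the mean already achieves that loss (a one-line baseline on the same data):

print(np.mean((Y_train - Y_train.mean()) ** 2))  # == Var(Y), ~1.0 for standard normal data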

Am I doing something wrong?

asked Nov 07 '22 by volatile

1 Answer

As you may read here, even though your model's states are not reset (thanks to the statefulness of your network), the parameters of your optimizer are, and since optimizer state is extremely important in recurrent neural network training, resetting it might be extremely harmful for your training. In order to prevent that, try:

for epoch in range(50):
    model.train_on_batch(X_train, Y_train)
    model.reset_states()

The train_on_batch method doesn't reset your optimizer state, which could make your training possible.
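
Since the model above was built with batch_input_shape=(1, 1, 1), the same idea applied one sample at a time would look like this (my sketch combining this answer with the loop from the question, not verbatim from the answer):

for epoch in range(50):
    for i in range(Y_train.shape[0]):
        # train_on_batch keeps the optimizer state across calls
        model.train_on_batch(X_train[i:i+1], Y_train[i:i+1])
    model.reset_states()  # resets only the LSTM states, once per pass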

answered Dec 02 '22 by Marcin Możejko