In Keras when does LSTM state reset in the call to model.predict?

Tags: keras, lstm

The model has an LSTM as its first layer.

When calling model.predict, say you pass in several samples:

sam = np.array([ [[.5, .6, .3]], [[.6, .6, .3]], [[.5, .6, .3]] ])
model.predict(sam)
array([[ 0.23589483],
       [ 0.2327884 ],
       [ 0.23589483]])

Above we see the mapping [[.5, .6, .3]] -> 0.23589483, etc. (a sequence of 1 element, which is a vector of length 3, is mapped to a real number).

The model has an input_length of 1 and an input_dim of 3. Notice that the first and last samples are the same and produce the same output (0.23589483). So my assumption is that after Keras processes a sample (in this case a sequence of one 3-D vector), it resets the model's memory; that is, each sequence is basically independent. Is there anything incorrect or misleading in this view?

Take another example with input_length 3 and input_dim 1. This time, switching values within a sequence gives a different result (compare the second list to the last). So memory changes as Keras processes a sequence, but once processing is done, the memory resets (the first and third sequences give the same result).

sam = np.array([ [[.1],[.1],[.9]], [[.1],[.9],[.1]], [[.1],[.1],[.9]] ])
model.predict(sam)
array([[ 0.69906837],
       [ 0.1454899 ],
       [ 0.69906837]])

Above we see the mapping [[.1],[.1],[.9]] -> 0.69906837, etc. (a sequence of 3 elements is mapped to a real number).
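
For reference, a stripped-down model of this kind reproduces the behavior (the layer size below is just a placeholder, not my actual model). With the default stateless LSTM, the identical first and third sequences produce identical outputs:

import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dense

# Placeholder architecture: sequences of length 3 with one feature per timestep.
model = Sequential()
model.add(LSTM(8, input_shape=(3, 1)))
model.add(Dense(1))
model.compile(loss="mse", optimizer="rmsprop")

sam = np.array([ [[.1],[.1],[.9]], [[.1],[.9],[.1]], [[.1],[.1],[.9]] ])
out = model.predict(sam)

# Each sample in the batch starts from a zero initial state, so the
# identical first and third sequences map to the same output value.
print(np.allclose(out[0], out[2]))  # True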

asked Aug 29 '16 by yalis


2 Answers

I appreciate this is an old question, but I hope this answer can help other Keras beginners like me.

I ran this example on my machine and observed that the hidden states and cell states of the LSTM were indeed changing with a call to model.predict.

import numpy as np
import keras.backend as K
from keras.models import Model
from keras.layers import LSTM, Input

batch_size = 1
timestep_size = 2
num_features = 4

inputs = Input(batch_shape=(batch_size, timestep_size, num_features))
x = LSTM(num_features, stateful=True)(inputs)

model = Model(inputs=inputs, outputs=x)
model.compile(loss="mse",
              optimizer="rmsprop",
              metrics=["accuracy"])

x = np.random.random((10, 2, 4))
y = np.ones((10, 4))
model.fit(x, y, epochs=100, batch_size=1)

def get_internal_state(model):
    # get the internal state of the LSTM
    # see https://github.com/fchollet/keras/issues/218
    h, c = [K.get_value(s) for s, _ in model.state_updates]
    return h, c

print "After fitting:", get_internal_state(model)

for i in range(3):
    x = np.random.random((10, 2, 4))
    model.predict(x)
    print("After predict:", get_internal_state(model))

Here's a sample of the output of the calls to get_internal_state after training:

After fitting: (array([[ 1.,  1.,  1.,  1.]], dtype=float32), array([[  11.33725166,   11.8036108 ,  181.75688171,   25.50110626]], dtype=float32))
After predict: (array([[ 1.        ,  0.99999994,  1.        ,  1.        ]], dtype=float32), array([[   9.26870918,    8.83847237,  179.92633057,   28.89341927]], dtype=float32))
After predict: (array([[ 0.99999571,  0.9992013 ,  1.        ,  0.9915328 ]], dtype=float32), array([[   6.5174489 ,    8.55165958,  171.42166138,   25.49199104]], dtype=float32))
After predict: (array([[ 1.,  1.,  1.,  1.]], dtype=float32), array([[   9.78496075,    9.27927303,  169.95401001,   28.74017715]], dtype=float32))
answered Oct 28 '22 by Geraint Ballinger


You are calling model.predict(), which means the network's weights do not change while it processes an input. So when you feed in [[.1],[.1],[.9]] it will always produce the same result, no matter what other inputs it receives in between. Note that this is the preferred behavior once you have trained your model and are predicting on test data: you don't want the other test samples you feed in to affect your predictions.
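
To check this, you can compare the weights before and after a call to predict. Below is a minimal sketch; the model, layer size, and data are stand-ins I am assuming, not the exact setup from the question:

import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dense

# Hypothetical stand-in for the question's model (layer size is an assumption).
model = Sequential()
model.add(LSTM(8, input_shape=(3, 1)))
model.add(Dense(1))
model.compile(loss="mse", optimizer="rmsprop")

sam = np.array([ [[.1],[.1],[.9]], [[.1],[.9],[.1]], [[.1],[.1],[.9]] ])

before = [w.copy() for w in model.get_weights()]
model.predict(sam)
after = model.get_weights()

# predict() is a forward pass only: every weight is left untouched.
print(all(np.array_equal(b, a) for b, a in zip(before, after)))  # True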

The effect you are expecting comes from training, as in model.fit(). For example, you can use model.train_on_batch() to train on an input (which updates the model's weights) and then call model.predict() to see the output change.
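
Continuing the sketch above (same hypothetical model and data), a single train_on_batch call is enough to shift the outputs:

# Train on one batch, then predict again.
x_batch = np.array([ [[.1],[.1],[.9]] ])
y_batch = np.array([[1.0]])

print(model.predict(sam))               # baseline outputs
model.train_on_batch(x_batch, y_batch)  # the weights are updated here
print(model.predict(sam))               # same inputs, different outputs now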

EDIT: if you're looking for the state of the LSTM rather than the weights of the network, you should pass stateful=True to the layer's constructor; it is set to False by default. When using stateful you also have to pass the batch_input_shape argument. See here for more info. Note that if you want each input to affect the next prediction you have to set the batch size to 1 (e.g. batch_input_shape=(1,3,1)), because evaluation is done in parallel for the samples in a batch and they don't affect each other.
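
Here is a minimal sketch of the stateful case (layer size and data are assumptions): with stateful=True and a batch size of 1, the state carries over from one predict call to the next until you reset it explicitly:

import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dense

# Stateful variant: batch_input_shape is required, and with batch size 1 the
# state left by one predict() call becomes the starting state of the next.
model = Sequential()
model.add(LSTM(8, batch_input_shape=(1, 3, 1), stateful=True))
model.add(Dense(1))
model.compile(loss="mse", optimizer="rmsprop")

seq = np.array([ [[.1],[.1],[.9]] ])

print(model.predict(seq))  # starts from a zero state
print(model.predict(seq))  # same input, different output: the state carried over

model.reset_states()       # explicitly clear the hidden and cell states
print(model.predict(seq))  # matches the first call again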

answered Oct 28 '22 by S.Mohsen sh