
How to extract cell state from an LSTM at each timestep in Keras?

Is there a way in Keras to retrieve the cell state (i.e., the c vector) of an LSTM layer at every timestep of a given input?

It seems the return_state argument returns only the last cell state once the computation is done, but I also need the intermediate ones. I don't want to pass these cell states to the next layer; I only want to be able to access them.

Preferably using TensorFlow as backend.

Thanks

Leafar asked Aug 27 '18 03:08


2 Answers

I was looking for a solution to this issue, and after reading the guidance on creating your own custom RNN cell in tf.keras (https://www.tensorflow.org/api_docs/python/tf/keras/layers/AbstractRNNCell), I believe the following is the most concise and readable way of doing this in TensorFlow 2:

import tensorflow as tf
from tensorflow.keras.layers import LSTMCell

class LSTMCellReturnCellState(LSTMCell):

    def call(self, inputs, states, training=None):
        # Keep the first `units` features (meant to drop c from a stacked [h, c] input; a no-op for the test input below)
        real_inputs = inputs[:, :self.units]
        outputs, [h,c] = super().call(real_inputs, states, training=training)
        return tf.concat([h, c], axis=1), [h,c]



num_units = 512
test_input = tf.random.uniform([5,100,num_units])

rnn = tf.keras.layers.RNN(LSTMCellReturnCellState(num_units),
                          return_sequences=True, return_state=True)

whole_seq_output, final_memory_state, final_carry_state = rnn(test_input)

print(whole_seq_output.shape)  # (5, 100, 1024)

# Hidden state sequence
h_seq = whole_seq_output[:,:,:num_units] # (5,100,512)

# Cell state sequence
c_seq = whole_seq_output[:,:,num_units:] # (5,100,512)

As mentioned in another answer, the advantage of this approach is that the cell can be wrapped in tf.keras.layers.RNN as a drop-in replacement for the normal LSTMCell.

Here is a Colab Notebook with the code running as expected for tensorflow==2.6.0
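
One way to sanity-check the split into h_seq and c_seq is to compare the last timestep of each against the final states the layer returns. The following is only a sketch under a few assumptions of mine: the sizes are small and arbitrary, and the input slice from the code above is left out because the input here is plain data rather than a previous [h, c] output.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import LSTMCell, RNN


class LSTMCellReturnCellState(LSTMCell):
    """LSTMCell whose per-step output is h and c concatenated on the feature axis."""

    def call(self, inputs, states, training=None):
        _, [h, c] = super().call(inputs, states, training=training)
        return tf.concat([h, c], axis=-1), [h, c]


# Arbitrary small sizes chosen for illustration (not from the original answer).
units, batch, steps, features = 4, 2, 6, 3

rnn = RNN(LSTMCellReturnCellState(units), return_sequences=True, return_state=True)
x = tf.random.uniform([batch, steps, features])
seq, final_h, final_c = rnn(x)

# First half of the feature axis is h, second half is c.
h_seq, c_seq = seq[:, :, :units], seq[:, :, units:]

# The last timestep of each sequence should match the final state from the layer.
np.testing.assert_allclose(h_seq[:, -1].numpy(), final_h.numpy(), rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(c_seq[:, -1].numpy(), final_c.numpy(), rtol=1e-5, atol=1e-5)
print(h_seq.shape, c_seq.shape)
```

If both assertions pass, the second half of the concatenated output really is the cell state at every timestep, not just at the end.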

Chau Luu answered Sep 17 '22 17:09


First, this is not possible with tf.keras.layers.LSTM. You have to use LSTMCell instead, or subclass LSTM. Second, there is no need to subclass LSTMCell to get the sequence of cell states: LSTMCell already returns a list containing the hidden state (h) and cell state (c) every time you call it. For those not familiar with LSTMCell, it takes the current [h, c] tensors and the input at the current timestep (it cannot take a sequence of timesteps) and returns the activations and the updated [h, c]. Here is an example showing how to use LSTMCell to process a sequence of timesteps and return the accumulated cell states.

import numpy as np
import tensorflow as tf

# example inputs
inputs = tf.convert_to_tensor(np.random.rand(3, 4), dtype='float32')  # 3 timesteps, 4 features
h_c = [tf.zeros((1, 2)), tf.zeros((1, 2))]  # initial [h, c] for the cell, each of shape (batch=1, units=2)
lstm = tf.keras.layers.LSTMCell(2)

# example of how you accumulate cell state over repeated calls to LSTMCell
inputs = tf.unstack(inputs, axis=0)
c_states = []
for cur_inputs in inputs:
    out, h_c = lstm(tf.expand_dims(cur_inputs, axis=0), h_c)
    h, c = h_c
    c_states.append(c)
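
To turn the collected list into a single (batch, timesteps, units) tensor, tf.stack should be enough. Below is a hedged, self-contained variant of the loop above that also cross-checks the last collected cell state against the same cell wrapped in tf.keras.layers.RNN. The variable names and the batched input layout are my own choices for illustration.

```python
import numpy as np
import tensorflow as tf

timesteps, features, units = 3, 4, 2
x = tf.random.uniform([1, timesteps, features])  # (batch, timesteps, features)

cell = tf.keras.layers.LSTMCell(units)
h_c = [tf.zeros((1, units)), tf.zeros((1, units))]  # initial [h, c]

c_states = []
for x_t in tf.unstack(x, axis=1):  # each x_t has shape (batch, features)
    _, h_c = cell(x_t, h_c)        # the cell returns (output, [h, c])
    c_states.append(h_c[1])

c_seq = tf.stack(c_states, axis=1)  # (batch, timesteps, units)

# Cross-check: the same cell inside an RNN layer starts from zero states too,
# so its final cell state should equal the last entry collected by hand.
_, final_h, final_c = tf.keras.layers.RNN(cell, return_state=True)(x)
np.testing.assert_allclose(c_seq[:, -1].numpy(), final_c.numpy(), rtol=1e-5, atol=1e-5)
print(c_seq.shape)
```

The manual loop runs step by step in Python, so for long sequences the RNN-wrapped approach is likely to be faster; the loop is mainly useful for inspection.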

Taw answered Sep 19 '22 17:09