I have followed the steps in https://machinelearningmastery.com/return-sequences-and-return-states-for-lstms-in-keras/ but when it comes to the Bidirectional LSTM, I tried this:
lstm, state_h, state_c = Bidirectional(LSTM(128, return_sequences=True, return_state=True))(input)
but it doesn't work.
Is there some approach to get both the final hidden state and the sequence from an LSTM layer when using a Bidirectional wrapper?
With a Bi-Directional LSTM, the final outputs are now a concatenation of the forwards and backwards directions. This is where it gets a little complicated, as the two directions will have seen different inputs for each output.
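A minimal sketch of how the concatenated sequence lines up with the final states, assuming TensorFlow 2.x with its bundled Keras (`tf.keras`) and arbitrary illustration sizes. The forward half of the last timestep equals the forward final hidden state, while the backward final hidden state sits in the backward half of the first timestep, because the backward LSTM reads the sequence from the end.

```python
import numpy as np
import tensorflow as tf

# Arbitrary illustration sizes (assumptions, not taken from the question).
batch, timesteps, features, units = 2, 5, 3, 4

x = np.random.rand(batch, timesteps, features).astype("float32")
bi = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(units, return_sequences=True, return_state=True)
)
seq, fwd_h, fwd_c, bwd_h, bwd_c = [np.asarray(t) for t in bi(x)]

# Default merge_mode="concat": the first `units` features come from the forward
# LSTM, the last `units` from the backward LSTM (re-aligned to input time order).
print(seq.shape)  # (2, 5, 8)

# The forward LSTM finishes at the last timestep...
forward_matches = np.allclose(seq[:, -1, :units], fwd_h, atol=1e-6)
# ...while the backward LSTM finishes at the first timestep.
backward_matches = np.allclose(seq[:, 0, units:], bwd_h, atol=1e-6)
print(forward_matches, backward_matches)  # True True
```

So "the final output" of the wrapped sequence is not a single moment in time: each half of each timestep has seen a different part of the input.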
Bidirectional LSTMs are an extension of traditional LSTMs that can improve model performance on sequence classification problems. In problems where all timesteps of the input sequence are available, Bidirectional LSTMs train two LSTMs instead of one: the first on the input sequence as-is and the second on a reversed copy of it.
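The "two LSTMs instead of one" can be seen directly on the wrapper. This sketch assumes TensorFlow 2.x `tf.keras`, where `Bidirectional` exposes `forward_layer` and `backward_layer`; the sizes are illustration values. The parameter count doubles because the two directions do not share weights.

```python
import numpy as np
import tensorflow as tf

units, features = 8, 4  # illustration sizes (assumptions)

bi = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units))
bi(np.zeros((1, 3, features), dtype="float32"))  # run once so the weights get created

# The wrapper holds two independent LSTMs; the backward one reads the input reversed.
print(bi.forward_layer.go_backwards, bi.backward_layer.go_backwards)  # False True

# One LSTM has 4 gates, each with input weights, recurrent weights and a bias.
single_lstm_params = 4 * (units * (features + units) + units)
print(bi.forward_layer.count_params(), bi.count_params())  # 416 832
```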
Creating a layer of LSTM memory units allows you to specify the number of memory units within the layer. Each unit or cell within the layer has an internal cell state, often abbreviated as “c”, and outputs a hidden state, often abbreviated as “h”.
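For a plain (non-bidirectional) LSTM, the same idea looks like this, assuming TensorFlow 2.x `tf.keras` and illustration sizes: with `return_state=True` the layer returns its output plus one h and one c value per unit, per sample.

```python
import numpy as np
import tensorflow as tf

units = 3  # number of memory units (cells) in the layer; illustration value

x = np.random.rand(2, 6, 5).astype("float32")  # (batch, timesteps, features)
lstm = tf.keras.layers.LSTM(units, return_state=True)
output, state_h, state_c = [np.asarray(t) for t in lstm(x)]

# One hidden state h and one cell state c per unit, for each sample in the batch.
print(state_h.shape, state_c.shape)  # (2, 3) (2, 3)

# Without return_sequences, the output is just the final hidden state h.
print(np.allclose(output, state_h))  # True
```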
First of all, the second layer won't have an output shape of 64 but of 128. This is because you are using a Bidirectional layer: the outputs of the forward and backward passes are concatenated, so your output shape will be (None, None, 64+64=128).
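A sketch of that shape arithmetic, assuming TensorFlow 2.x `tf.keras` and a hypothetical model that stacks two `Bidirectional(LSTM(..., return_sequences=True))` layers on an input with a variable number of timesteps (the feature size 10 and the first layer's 128 units are made up for illustration). Passing `merge_mode="sum"` instead of the default concatenation keeps 64 features.

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(None, 10))  # (timesteps, features); sizes are hypothetical
x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128, return_sequences=True))(inputs)
x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True))(x)
print(tuple(x.shape))  # (None, None, 128): 64 forward + 64 backward

# merge_mode="sum" adds the two directions element-wise instead, keeping 64 features.
y = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(64, return_sequences=True), merge_mode="sum"
)(inputs)
print(tuple(y.shape))  # (None, None, 64)
```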
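The call Bidirectional(LSTM(128, return_sequences=True, return_state=True))(input) returns 5 tensors:
1. the entire sequence of hidden states (by default, the concatenation of the forward and backward outputs)
2. the last hidden state h for the forward LSTM
3. the last cell state c for the forward LSTM
4. the last hidden state h for the backward LSTM
5. the last cell state c for the backward LSTM

The line you've posted would raise an error because it tries to unpack the returned value into just three variables (lstm, state_h, state_c).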
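To check this yourself, here is a small sketch assuming TensorFlow 2.x `tf.keras` and a hypothetical input with 10 features: it prints the five output shapes and reproduces the unpacking error from the question.

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(None, 10))  # hypothetical feature size
outputs = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(128, return_sequences=True, return_state=True)
)(inputs)

print(len(outputs))  # 5
for t in outputs:
    print(tuple(t.shape))
# (None, None, 256)  sequence: forward and backward outputs concatenated
# (None, 128)        forward h
# (None, 128)        forward c
# (None, 128)        backward h
# (None, 128)        backward c

# Unpacking five tensors into three names fails, as in the question.
raised = False
try:
    lstm, state_h, state_c = outputs
except ValueError as err:
    raised = True
    print("ValueError:", err)
```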
To correct it, simply unpack the returned value into 5 variables. If you want to merge the states, you can concatenate the forward and backward states with Concatenate layers.
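from tensorflow.keras.layers import LSTM, Bidirectional, Concatenate
# input: your model's input tensor, e.g. Input(shape=(timesteps, features))
lstm, forward_h, forward_c, backward_h, backward_c = Bidirectional(LSTM(128, return_sequences=True, return_state=True))(input)
state_h = Concatenate()([forward_h, backward_h])  # shape (batch, 256)
state_c = Concatenate()([forward_c, backward_c])  # shape (batch, 256)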
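An end-to-end sketch of the fix, assuming TensorFlow 2.x `tf.keras` and small illustration sizes. The decoder LSTM and its input are hypothetical additions, included only to show one common reason to merge the states: a unidirectional LSTM with 2 × units can take the concatenated h and c as its `initial_state`.

```python
import numpy as np
import tensorflow as tf

layers = tf.keras.layers
units = 8  # encoder units per direction (illustration value)

enc_in = tf.keras.Input(shape=(None, 4))  # hypothetical encoder features
seq, fwd_h, fwd_c, bwd_h, bwd_c = layers.Bidirectional(
    layers.LSTM(units, return_sequences=True, return_state=True)
)(enc_in)
state_h = layers.Concatenate()([fwd_h, bwd_h])  # (batch, 2 * units)
state_c = layers.Concatenate()([fwd_c, bwd_c])  # (batch, 2 * units)

# Hypothetical decoder: it needs 2 * units so the merged states fit initial_state.
dec_in = tf.keras.Input(shape=(None, 3))
dec_seq = layers.LSTM(2 * units, return_sequences=True)(
    dec_in, initial_state=[state_h, state_c]
)

model = tf.keras.Model([enc_in, dec_in], [seq, state_h, state_c, dec_seq])
enc_x = np.random.rand(2, 5, 4).astype("float32")
dec_x = np.random.rand(2, 7, 3).astype("float32")
seq_v, h_v, c_v, dec_v = [np.asarray(t) for t in model([enc_x, dec_x])]
print(seq_v.shape, h_v.shape, c_v.shape, dec_v.shape)
# (2, 5, 16) (2, 16) (2, 16) (2, 7, 16)
```

The merged state_h keeps the forward state in its first half and the backward state in its second half, so it still matches the last and first timesteps of the sequence output respectively.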