Passing initial_state to Bidirectional RNN layer in Keras

I'm trying to implement an encoder-decoder network in Keras with bidirectional GRUs.

The following code works as expected:

from keras.layers import Input, Embedding, GRU, Bidirectional

src_input = Input(shape=(5,))
ref_input = Input(shape=(5,))

src_embedding = Embedding(output_dim=300, input_dim=vocab_size)(src_input)
ref_embedding = Embedding(output_dim=300, input_dim=vocab_size)(ref_input)

# Returns [outputs, forward_state, backward_state]
encoder = Bidirectional(
                GRU(2, return_sequences=True, return_state=True)
        )(src_embedding)

# Initialize the decoder with the encoder's forward state
decoder = GRU(2, return_sequences=True)(ref_embedding, initial_state=encoder[1])
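For context (my reading, not part of the question): with return_state=True, the Bidirectional wrapper returns a list of three tensors, [outputs, forward_state, backward_state], so encoder[1] above is the forward GRU's final state. A minimal sketch to confirm the shapes, assuming tf.keras and a placeholder vocab_size:

```python
from tensorflow.keras.layers import Input, Embedding, GRU, Bidirectional

vocab_size = 1000  # placeholder; the question does not show its value

src_input = Input(shape=(5,), dtype="int32")
src_embedding = Embedding(output_dim=300, input_dim=vocab_size)(src_input)

# Bidirectional(..., return_state=True) returns three tensors:
# the full sequence output plus the final state of each direction.
outputs = Bidirectional(
    GRU(2, return_sequences=True, return_state=True)
)(src_embedding)

seq, fwd_state, bwd_state = outputs
# seq:       (batch, 5, 4) -- forward and backward outputs concatenated
# fwd_state: (batch, 2)
# bwd_state: (batch, 2)
```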

But when I change the decoder to use the Bidirectional wrapper as well, model.summary() stops showing the encoder and src_input layers. The new decoder looks like:

decoder = Bidirectional(
                GRU(2, return_sequences=True)
        )(ref_embedding, initial_state=encoder[1:])

The output of model.summary() with the Bidirectional decoder:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         (None, 5)                 0         
_________________________________________________________________
embedding_2 (Embedding)      (None, 5, 300)            6610500   
_________________________________________________________________
bidirectional_2 (Bidirection (None, 5, 4)              3636      
=================================================================
Total params: 6,614,136
Trainable params: 6,614,136
Non-trainable params: 0
_________________________________________________________________

Question: Am I missing something when passing initial_state to the Bidirectional decoder? How can I fix this, or is there another way to make it work?

asked Nov 08 '22 by nisargjhaveri

1 Answer

It's a bug. The RNN layer implements __call__ so that the tensors passed in initial_state can be collected into a model instance, but the Bidirectional wrapper did not implement it. As a result, the topological information about the initial_state tensors is missing, which leads to strange bugs like this one.

I wasn't aware of this case when I implemented initial_state support for Bidirectional. It should be fixed now, after this PR. You can install Keras from the latest master branch on GitHub to get the fix.
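If upgrading is not an option, one workaround (my own sketch, not from the answer) is to build the bidirectional decoder manually from two unidirectional GRUs, since each GRU's own __call__ handles initial_state correctly. Assuming tf.keras and a placeholder vocab_size:

```python
from tensorflow.keras.layers import (Input, Embedding, GRU, Bidirectional,
                                     Concatenate, Lambda)
from tensorflow.keras.models import Model

vocab_size = 1000  # placeholder; the question does not show its value

src_input = Input(shape=(5,), dtype="int32")
ref_input = Input(shape=(5,), dtype="int32")

src_embedding = Embedding(output_dim=300, input_dim=vocab_size)(src_input)
ref_embedding = Embedding(output_dim=300, input_dim=vocab_size)(ref_input)

# encoder returns [outputs, forward_state, backward_state]
encoder = Bidirectional(
    GRU(2, return_sequences=True, return_state=True)
)(src_embedding)

# Manual "bidirectional" decoder: one GRU per direction, each seeded
# with the matching encoder state.
fwd = GRU(2, return_sequences=True)(ref_embedding,
                                    initial_state=encoder[1])
bwd = GRU(2, return_sequences=True, go_backwards=True)(
    ref_embedding, initial_state=encoder[2])
# go_backwards emits the sequence in reverse time order; flip it back
# so the two directions align, as Bidirectional does internally.
bwd = Lambda(lambda t: t[:, ::-1, :])(bwd)
decoder = Concatenate()([fwd, bwd])

model = Model([src_input, ref_input], decoder)
model.summary()  # both inputs and the encoder now appear in the graph
```

This avoids the wrapper's missing __call__ entirely, at the cost of slightly more verbose code.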

answered Nov 30 '22 by Yu-Yang