
LSTM Followed by Mean Pooling

I'm using Keras 1.0. My problem is the same as the one in "How to implement a Mean Pooling layer in Keras", but the answer there does not solve it for me.

I want to implement this network: [image not preserved]

The following code does not work:

sequence = Input(shape=(max_sent_len,), dtype='int32')
embedded = Embedding(vocab_size, word_embedding_size)(sequence)
lstm = LSTM(hidden_state_size, activation='sigmoid', inner_activation='hard_sigmoid', return_sequences=True)(embedded)
pool = AveragePooling1D()(lstm)
output = Dense(1, activation='sigmoid')(pool)

If I don't set return_sequences=True, I get this error when I call AveragePooling1D():

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/PATH/keras/engine/topology.py", line 462, in __call__
    self.assert_input_compatibility(x)
  File "/PATH/keras/engine/topology.py", line 382, in assert_input_compatibility
    str(K.ndim(x)))
Exception: ('Input 0 is incompatible with layer averagepooling1d_6: expected ndim=3', ' found ndim=2')

If I do set return_sequences=True, I get this error when I call Dense():

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/PATH/keras/engine/topology.py", line 456, in __call__
    self.build(input_shapes[0])
  File "/fs/clip-arqat/mossaab/trec/liveqa/cmu/venv/lib/python2.7/site-packages/keras/layers/core.py", line 512, in build
    assert len(input_shape) == 2
AssertionError
mossaab asked Apr 05 '16 13:04


2 Answers

I just implemented the same model as the original poster, using Keras 2.0.3. Mean pooling after the LSTM worked when I used GlobalAveragePooling1D; just make sure to set return_sequences=True in the LSTM layer. Give it a try!
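A minimal sketch of that setup, for reference. The sizes are hypothetical values picked only so the snippet runs, and it assumes a Keras 2.x or later install where these imports resolve. The LSTM activations are left at their defaults rather than copying the Keras 1.0 arguments from the question.

```python
from keras.layers import Input, Embedding, LSTM, GlobalAveragePooling1D, Dense
from keras.models import Model

# Hypothetical sizes, chosen only for illustration (not from the question)
max_sent_len = 20
vocab_size = 1000
word_embedding_size = 16
hidden_state_size = 8

sequence = Input(shape=(max_sent_len,), dtype='int32')
embedded = Embedding(vocab_size, word_embedding_size)(sequence)
# return_sequences=True keeps one hidden state per time step (a 3D tensor)
lstm = LSTM(hidden_state_size, return_sequences=True)(embedded)
# Averages over the time axis, leaving a 2D tensor that Dense accepts
pool = GlobalAveragePooling1D()(lstm)
output = Dense(1, activation='sigmoid')(pool)

model = Model(inputs=sequence, outputs=output)
```

GlobalAveragePooling1D removes the time axis entirely, which is why Dense no longer complains about the number of dimensions.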

jypucca answered Sep 22 '22 09:09


Adding TimeDistributed(Dense(1)) helped:

sequence = Input(shape=(max_sent_len,), dtype='int32')
embedded = Embedding(vocab_size, word_embedding_size)(sequence)
lstm = LSTM(hidden_state_size, activation='sigmoid', inner_activation='hard_sigmoid', return_sequences=True)(embedded)
distributed = TimeDistributed(Dense(1))(lstm)
pool = AveragePooling1D()(distributed)
output = Dense(1, activation='sigmoid')(pool)
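To see why the dimensions matter, here is a rough NumPy sketch of the shape arithmetic. It is only an illustration, not Keras code: the array stands in for the LSTM output with return_sequences=True, the sizes are made up, and it assumes the default AveragePooling1D window of 2 with a matching stride.

```python
import numpy as np

# Hypothetical sizes for illustration; max_sent_len is even so the windows divide evenly
batch, max_sent_len, hidden_state_size = 4, 6, 3
rng = np.random.default_rng(0)

# Stand-in for the LSTM output with return_sequences=True: (batch, time, features)
lstm_out = rng.normal(size=(batch, max_sent_len, hidden_state_size))

# Windowed average pooling (window 2, stride 2): halves the time axis, result is still 3D
windowed = lstm_out.reshape(batch, max_sent_len // 2, 2, hidden_state_size).mean(axis=2)

# Global mean pooling: averages over the whole time axis, result is 2D
global_mean = lstm_out.mean(axis=1)

print(windowed.shape)     # (4, 3, 3)
print(global_mean.shape)  # (4, 3)
```

The windowed result keeps a time axis, so a layer that expects 2D input still sees 3D input; only the global mean collapses it.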
mossaab answered Sep 20 '22 09:09