
Mean or max pooling with masking support in Keras

...
print('Build model...')
model = Sequential()
model.add(Embedding(max_features, 128))
model.add(LSTM(size, return_sequences=True, dropout_W=0.2, dropout_U=0.2))
model.add(GlobalAveragePooling1D())
model.add(Dense(1))
model.add(Activation('sigmoid'))
....

After the LSTM layer, I need to take the mean or max of the vectors over all time steps in a sample, and then pass this mean or max vector to the dense layer in Keras.

I think TimeDistributedMerge was able to do this, but it has been deprecated. With return_sequences=True I can obtain the vectors for all time steps in a sample after the LSTM layer. However, GlobalAveragePooling1D() is not compatible with masking: it averages over all time steps, whereas I need only the non-masked ones.
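To make the difference concrete, here is a minimal NumPy sketch (not Keras code) of what a mask-unaware average computes versus the masked mean being asked for. The array values, the mask, and the variable names are made up for illustration; the 9.0 rows simply stand in for whatever the layer emits at padded positions.

```python
import numpy as np

# Hypothetical LSTM output for one sample: 4 time steps, 2 features.
# The last two steps stand for padding that the mask should exclude.
outputs = np.array([[[1.0, 2.0],
                     [3.0, 4.0],
                     [9.0, 9.0],
                     [9.0, 9.0]]])                 # (batch=1, time=4, features=2)
mask = np.array([[True, True, False, False]])     # (batch, time)

# A mask-unaware global average: the mean over all 4 steps.
plain_mean = outputs.mean(axis=1)                 # values 5.5 and 6.0

# The desired result: the mean over the 2 unmasked steps only.
weights = mask[:, :, None].astype(outputs.dtype)  # (batch, time, 1)
masked_mean = (outputs * weights).sum(axis=1) / weights.sum(axis=1)  # values 2.0 and 3.0

print(plain_mean)
print(masked_mean)
```

The padded steps pull the plain mean away from the true average of the real time steps, which is the behaviour I want to avoid.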

I saw posts recommending the Lambda layer but these also do not take masking into account. Any help would be appreciated.

Ersin asked Sep 15 '16 12:09


2 Answers

Jacoxu's answer is right, but if you are using the TensorFlow backend for Keras, the Tensor type does not support the dimshuffle function. Try this instead:

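import tensorflow as tf
from keras import backend as K

def call(self, x, mask=None):
    if mask is None:
        # no mask given: plain mean over the time axis
        return K.mean(x, axis=1)
    # mask: (batch, time)
    mask = K.cast(mask, K.floatx())
    # mask: (batch, x_dim, time)
    mask = K.repeat(mask, x.shape[-1])
    # mask: (batch, time, x_dim)
    mask = tf.transpose(mask, [0, 2, 1])
    # zero out the masked time steps
    x = x * mask
    # sum over unmasked steps divided by their count
    return K.sum(x, axis=1) / K.sum(mask, axis=1)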
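The shape handling above can be checked outside Keras. The following is a rough NumPy sketch of the same steps, where np.repeat and transpose stand in for K.repeat and tf.transpose; the sizes, the random input, and the mask are assumptions chosen only for illustration.

```python
import numpy as np

batch, time, x_dim = 2, 3, 4
rng = np.random.default_rng(0)
x = rng.normal(size=(batch, time, x_dim))
# Hypothetical mask: sample 0 has 2 valid steps, sample 1 has 3.
mask = np.array([[1, 1, 0],
                 [1, 1, 1]], dtype=x.dtype)              # (batch, time)

# Counterpart of K.repeat: copy the mask once per feature.
repeated = np.repeat(mask[:, None, :], x_dim, axis=1)    # (batch, x_dim, time)
# Counterpart of tf.transpose(mask, [0, 2, 1]).
aligned = repeated.transpose(0, 2, 1)                    # (batch, time, x_dim)

# Sum over unmasked steps divided by the number of unmasked steps.
pooled = (x * aligned).sum(axis=1) / aligned.sum(axis=1)  # (batch, x_dim)

# Reference: average only the valid prefix of each sample.
expected = np.stack([x[0, :2].mean(axis=0), x[1, :3].mean(axis=0)])
print(np.allclose(pooled, expected))
```

Note that a sample whose time steps are all masked would divide by zero here, so real code may want a small epsilon in the denominator.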
Geralt Xu answered Oct 11 '22 20:10


Since average pooling just takes a mean over one axis, you only need to correct the number of elements in the mean; loss masking is handled at the end, not here. Something like this should probably work:

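from keras import backend as K
from keras.layers import GlobalAveragePooling1D

class GlobalAveragePooling1DMasked(GlobalAveragePooling1D):
    def call(self, x, mask=None):
        if mask is not None:
            # mask: (batch, time) -> number of unmasked steps, shape (batch, 1)
            count = K.sum(K.cast(mask, K.floatx()), axis=1, keepdims=True)
            # assumes the masked time steps of x are already zero
            return K.sum(x, axis=1) / count
        else:
            return super().call(x)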
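For the max-pooling half of the question, the count correction does not apply; one common approach is to push the masked steps to negative infinity before taking the max. Below is a small NumPy sketch of that idea, not a Keras layer; the function name masked_global_max and the sample data are hypothetical.

```python
import numpy as np

def masked_global_max(x, mask):
    """Max over the time axis of x (batch, time, features),
    ignoring the steps where mask (batch, time) is False."""
    keep = mask.astype(bool)[:, :, None]        # (batch, time, 1)
    # Masked steps become -inf so they can never win the max.
    return np.where(keep, x, -np.inf).max(axis=1)

x = np.array([[[1.0, -5.0],
               [0.5, -2.0],
               [7.0,  0.0]]])                   # last step is padding
mask = np.array([[True, True, False]])

result = masked_global_max(x, mask)             # values 1.0 and -2.0
print(result)
```

A fully masked sample would come out as -inf under this scheme, so it may need special handling depending on the data.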
nemo answered Oct 11 '22 18:10