
How to change batch size of an intermediate layer in Keras?



I want to take all of the hidden-state outputs of an LSTM (one per timestep) and use each of them as a separate training example for a single dense layer. Flattening the hidden-layer output into one long vector per sample and feeding that to a dense layer is not what I am looking for. I have tried the following:

  1. I considered the TimeDistributed wrapper for the dense layer (https://keras.io/layers/wrappers/). However, it seems to apply the same layer to every time slice, which is not what I want. In other words, the TimeDistributed wrapper takes a 3D tensor of shape (number of samples, number of timesteps, number of features) as input and produces another 3D tensor of the same form, (number of samples, number of timesteps, number of features). What I want instead is a 2D output tensor of shape (number of samples * number of timesteps, number of features).

  2. There was a pull request on GitHub for an AdvancedReshapeLayer: https://github.com/fchollet/keras/pull/36. This seems to be exactly what I am looking for. Unfortunately, that pull request appears to have been closed without a conclusive outcome.

  3. I tried to build my own Lambda layer to accomplish this, as follows:

     (A) model.add(LSTM(NUM_LSTM_UNITS, return_sequences=True, activation='tanh'))
     (B) model.add(Lambda(lambda x: x, output_shape=lambda x: (x[0]*x[1], x[2])))
     (C) model.add(Dense(NUM_CLASSES, input_dim=NUM_LSTM_UNITS))

model.output_shape after (A) prints (BATCH_SIZE, NUM_TIME_STEPS, NUM_LSTM_UNITS), and model.output_shape after (B) prints (BATCH_SIZE*NUM_TIME_STEPS, NUM_LSTM_UNITS), which is exactly what I am trying to achieve.

Unfortunately, when I try to run step (C), I get the following error:

Input 0 is incompatible with layer dense_1: expected ndim=2, found ndim=3

This is baffling, since when I print model.output_shape after (B), I do indeed see (BATCH_SIZE*NUM_TIME_STEPS, NUM_LSTM_UNITS), which has ndim=2.
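A likely explanation, sketched with NumPy arrays standing in for the Keras tensors (the sizes are arbitrary assumptions): in step (B), `output_shape` only declares the shape that Keras reports, while the layer function `lambda x: x` returns its input unchanged. The tensor that actually reaches the Dense layer in step (C) is therefore still 3D, which matches "found ndim=3" in the error message.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
NUM_SAMPLES, NUM_TIME_STEPS, NUM_LSTM_UNITS = 2, 3, 4


def declared_output_shape(input_shape):
    # Mirrors output_shape=lambda x: (x[0]*x[1], x[2]) from step (B).
    return (input_shape[0] * input_shape[1], input_shape[2])


def layer_function(x):
    # Mirrors lambda x: x from step (B): the data is returned unchanged.
    return x


x = np.zeros((NUM_SAMPLES, NUM_TIME_STEPS, NUM_LSTM_UNITS))

declared = declared_output_shape(x.shape)
actual = layer_function(x).shape

print(declared)  # (6, 4): the shape step (B) declares
print(actual)    # (2, 3, 4): the shape of the tensor it actually returns
```

In other words, the shape function and the layer function have to agree; changing only `output_shape` does not reshape any data.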

I would really appreciate any help with this.

EDIT: When I use the functional API instead of a Sequential model, I still get the same error at step (C).

shoaib369 asked Dec 22 '16 21:12


1 Answer

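You can use the backend reshape function. Its target shape includes the batch-size dimension (unlike the Reshape layer, whose target_shape excludes the batch axis), so it can merge the batch and time axes: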

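from keras import backend
from keras.layers import Lambda

def backend_reshape(x):
    # Merge batch and time axes: (batch, timesteps, units) -> (batch*timesteps, units)
    return backend.reshape(x, (-1, NUM_LSTM_UNITS))

model.add(Lambda(backend_reshape, output_shape=(NUM_LSTM_UNITS,)))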
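To illustrate what this reshape does to the data, here is a NumPy sketch under stated assumptions: `hidden` stands in for the LSTM output with return_sequences=True, and the weights `W`, `b` and the per-timestep `labels` are hypothetical. After the reshape, row i * NUM_TIME_STEPS + t holds timestep t of sample i, and the dense layer applies one shared weight matrix to every row.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes for illustration only.
BATCH_SIZE, NUM_TIME_STEPS, NUM_LSTM_UNITS, NUM_CLASSES = 2, 3, 4, 5

# Stand-in for the LSTM output: (BATCH_SIZE, NUM_TIME_STEPS, NUM_LSTM_UNITS).
hidden = rng.standard_normal((BATCH_SIZE, NUM_TIME_STEPS, NUM_LSTM_UNITS))

# NumPy analogue of backend.reshape(x, (-1, NUM_LSTM_UNITS)).
flat = hidden.reshape(-1, NUM_LSTM_UNITS)

# Row-major order: row i * NUM_TIME_STEPS + t is timestep t of sample i.
i, t = 1, 1
assert np.array_equal(flat[i * NUM_TIME_STEPS + t], hidden[i, t])

# A single dense layer: one weight matrix and bias shared by every row.
W = rng.standard_normal((NUM_LSTM_UNITS, NUM_CLASSES))
b = np.zeros(NUM_CLASSES)
logits = flat @ W + b  # shape (BATCH_SIZE * NUM_TIME_STEPS, NUM_CLASSES)

# Hypothetical per-timestep labels, flattened in the same row order
# so that target row k lines up with prediction row k.
labels = rng.integers(0, NUM_CLASSES, size=(BATCH_SIZE, NUM_TIME_STEPS))
flat_labels = labels.reshape(-1)

print(logits.shape, flat_labels.shape)  # (6, 5) (6,)
```

One consequence worth keeping in mind: the model now produces BATCH_SIZE * NUM_TIME_STEPS predictions per batch, so any per-timestep targets need to be flattened the same way.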
user3473614 answered Nov 01 '22 13:11
