 

How to change batch size of an intermediate layer in Keras?

Tags:

keras

I want to take all the hidden outputs from an LSTM and use them as training examples for a single dense layer. Flattening the hidden outputs and feeding them to a dense layer is not what I am looking to do. I have tried the following:

  1. I considered the TimeDistributed wrapper for the dense layer (https://keras.io/layers/wrappers/), but it applies the same layer to every time slice, which is not what I want. In other words, TimeDistributed takes a 3D input tensor of shape (number of samples, number of timesteps, number of features) and produces another 3D tensor of the same kind: (number of samples, number of timesteps, number of features). What I want instead is a 2D output tensor of shape (number of samples * number of timesteps, number of features); see the sketch after this list.

  2. There was a pull request for an AdvancedReshapeLayer on GitHub: https://github.com/fchollet/keras/pull/36. It seems to be exactly what I am looking for, but unfortunately the pull request was closed without being merged.

  3. I tried to build my own Lambda layer to accomplish what I want:

     model.add(LSTM(NUM_LSTM_UNITS, return_sequences=True, activation='tanh'))  # (A)
     model.add(Lambda(lambda x: x, output_shape=lambda x: (x[0]*x[1], x[2])))   # (B)
     model.add(Dense(NUM_CLASSES, input_dim=NUM_LSTM_UNITS))                    # (C)
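To make the shape contrast in point 1 concrete, here is a minimal sketch with hypothetical dimensions (none of these numbers are from my real model, and the import path may differ across Keras versions):

from keras.models import Sequential
from keras.layers import LSTM, Dense, TimeDistributed

# Hypothetical dimensions, for illustration only
NUM_TIME_STEPS = 10
NUM_FEATURES = 8
NUM_LSTM_UNITS = 32

model = Sequential()
model.add(LSTM(NUM_LSTM_UNITS, return_sequences=True,
               input_shape=(NUM_TIME_STEPS, NUM_FEATURES)))
model.add(TimeDistributed(Dense(4)))
# model.output_shape is (None, NUM_TIME_STEPS, 4) -- still 3D.
# What I want instead is a 2D tensor shaped
# (number of samples * NUM_TIME_STEPS, number of features).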

Back to attempt 3: model.output_shape after (A) prints (BATCH_SIZE, NUM_TIME_STEPS, NUM_LSTM_UNITS), and model.output_shape after (B) prints (BATCH_SIZE * NUM_TIME_STEPS, NUM_LSTM_UNITS), which is exactly what I am trying to achieve.

Unfortunately, when I try to run step (C), I get the following error:

Input 0 is incompatible with layer dense_1: expected ndim=2, found ndim=3

This is baffling, since when I print model.output_shape after (B) I do indeed see (BATCH_SIZE * NUM_TIME_STEPS, NUM_LSTM_UNITS), which is ndim=2.

Really appreciate any help with this.

EDIT: When I use the functional API instead of a Sequential model, I still get the same error at step (C).

asked Dec 22 '16 by shoaib369


1 Answer

You can use the backend reshape function, which lets you fold the batch_size dimension into the reshape. Note that the output_shape argument of a Lambda layer only declares a shape for Keras's bookkeeping; your identity function lambda x: x in (B) never actually reshapes the tensor, so the Dense layer in (C) still receives a 3D tensor. Do the reshape inside the Lambda's function:

from keras import backend

def backend_reshape(x):
    # (batch, timesteps, units) -> (batch * timesteps, units);
    # -1 lets the backend infer batch_size * timesteps at runtime
    return backend.reshape(x, (-1, NUM_LSTM_UNITS))

model.add(Lambda(backend_reshape, output_shape=(NUM_LSTM_UNITS,)))
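For context, here is a minimal end-to-end sketch of how this fits into a model; the dimensions, NUM_CLASSES, and the input shape are placeholders I made up, not values from the question:

from keras.models import Sequential
from keras.layers import LSTM, Dense, Lambda
from keras import backend

# Placeholder dimensions, for illustration only
NUM_TIME_STEPS = 10
NUM_FEATURES = 8
NUM_LSTM_UNITS = 32
NUM_CLASSES = 5

def backend_reshape(x):
    # (batch, timesteps, units) -> (batch * timesteps, units)
    return backend.reshape(x, (-1, NUM_LSTM_UNITS))

model = Sequential()
model.add(LSTM(NUM_LSTM_UNITS, return_sequences=True, activation='tanh',
               input_shape=(NUM_TIME_STEPS, NUM_FEATURES)))
model.add(Lambda(backend_reshape, output_shape=(NUM_LSTM_UNITS,)))
model.add(Dense(NUM_CLASSES, activation='softmax'))
# model.output_shape is now (None, NUM_CLASSES): every timestep of every
# sequence is treated as its own training example.

Keep in mind that your targets must then have a matching first dimension of batch_size * NUM_TIME_STEPS.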
answered Nov 01 '22 by user3473614