Dimension Mismatch in LSTM Keras

I want to create a basic RNN that can add two bytes. Here are the inputs and the expected output of a simple addition:

X = [[0, 0], [0, 1], [1, 1], [0, 1], [1, 0], [1, 0], [1, 1], [1, 0]]

That is, the first column is X1 = 00101111 and the second is X2 = 01110010.

Y = [1, 0, 1, 0, 0, 0, 0, 1]
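As a quick sanity check on the data, the two columns of X can be read as 8-bit numbers (most significant bit first) and their sum compared with Y. This is a minimal sketch in plain Python; `bits_to_int` is a hypothetical helper, not part of Keras.

```python
# X from the question: one [bit_of_X1, bit_of_X2] pair per position,
# written most-significant bit first.
X = [[0, 0], [0, 1], [1, 1], [0, 1], [1, 0], [1, 0], [1, 1], [1, 0]]
Y = [1, 0, 1, 0, 0, 0, 0, 1]

def bits_to_int(bits):
    """Interpret a list of bits (most significant first) as an unsigned integer."""
    value = 0
    for b in bits:
        value = value * 2 + b
    return value

x1 = bits_to_int([pair[0] for pair in X])  # 0b00101111
x2 = bits_to_int([pair[1] for pair in X])  # 0b01110010
y = bits_to_int(Y)                         # 0b10100001

print(x1, x2, y)     # 47 114 161
print(x1 + x2 == y)  # True
```

So Y is indeed X1 + X2 (47 + 114 = 161), and the sum still fits in 8 bits.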

I created the following Sequential model:

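from keras.models import Sequential
from keras.layers import GRU, Activation, Dense

model = Sequential()
# Keras 1.x arguments: the layer expects input shaped (batch, input_length, input_dim)
model.add(GRU(output_dim = 16, input_length = 2, input_dim = 8))
model.add(Activation('relu'))
model.add(Dense(2, activation='softmax'))
model.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics = ['accuracy'])
model.summary()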

The error I get is something along the lines of:

expected lstm_input_1 to have 3 dimensions, but got array with shape (8L, 2L)

If I add a dimension by changing X to

[[[0 0]] [[1 1]] [[1 1]] [[1 0]] [[0 0]] [[1 0]] [[0 1]] [[1 0]]]

then the error changes to:

expected lstm_input_1 to have shape (None, 8, 2) but got array with shape (8L, 1L, 2L)
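The shapes behind both errors can be reproduced with NumPy alone. This sketch rebuilds the second array from the values of X (the array shown in the question uses different bit pairs, but has the same shape); `expected` is simply the shape copied from the error message. NumPy on Python 3 prints shapes without the `L` suffix that Python 2 added to long integers.

```python
import numpy as np

X = np.array([[0, 0], [0, 1], [1, 1], [0, 1], [1, 0], [1, 0], [1, 1], [1, 0]])
print(X.shape)          # (8, 2): only 2D, hence "expected ... 3 dimensions"

# Wrapping each pair in its own list, as in the second attempt
X_wrapped = X[:, np.newaxis, :]
print(X_wrapped.shape)  # (8, 1, 2): 3D, read as 8 samples x 1 time step x 2 features

# Shape from the second error: any batch size, then 8 time steps x 2 features
expected = (None, 8, 2)
print(X_wrapped.shape[1:] == expected[1:])  # False: (1, 2) != (8, 2)
```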

asked Jun 29 '16 at 17:06 by Yesh



3 Answers

In Keras, recurrent layers in a Sequential model expect input of shape (batch_size, sequence_length, input_dimension). I suspect you need to change the last two dimensions of your input array. Remember that the batch dimension is not specified explicitly when you declare the layer's input shape.
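To make the "batch dimension is not explicitly defined" point concrete, here is a small sketch of the shape rule. `fits_recurrent_input` is a hypothetical helper that only loosely mimics the check Keras performs: the declared input shape covers (sequence_length, input_dimension), and axis 0 (the batch) may be any size. The target (8, 2) is taken from the (None, 8, 2) in the question's second error.

```python
import numpy as np

def fits_recurrent_input(array, input_shape):
    """Hypothetical check: a recurrent layer declared with
    input_shape=(sequence_length, input_dimension) accepts arrays shaped
    (batch_size, sequence_length, input_dimension) for any batch_size."""
    return array.ndim == 3 and array.shape[1:] == tuple(input_shape)

X = np.array([[0, 0], [0, 1], [1, 1], [0, 1], [1, 0], [1, 0], [1, 1], [1, 0]])
input_shape = (8, 2)  # (sequence_length, input_dimension); batch size omitted

print(fits_recurrent_input(X, input_shape))                   # False: only 2D
print(fits_recurrent_input(X.reshape(8, 1, 2), input_shape))  # False: (1, 2) per sample
print(fits_recurrent_input(X.reshape(1, 8, 2), input_shape))  # True
```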

answered Nov 03 '22 at 17:11 by Simon Kamronn



Change X to [[[0, 0], [0, 1], [1, 1], [0, 1], [1, 0], [1, 0], [1, 1], [1, 0]]] (note the extra outer brackets) so that its shape is (1, 8, 2): one sample of 8 time steps with 2 features each.
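The same (1, 8, 2) array can be built in NumPy either by adding the outer brackets or with `np.expand_dims`, which inserts a new leading (batch) axis. A minimal sketch:

```python
import numpy as np

X = [[0, 0], [0, 1], [1, 1], [0, 1], [1, 0], [1, 0], [1, 1], [1, 0]]

# Extra pair of brackets: a batch containing a single sequence
X_batch = np.array([X])
# Equivalent: add a leading batch axis to the (8, 2) array
X_expanded = np.expand_dims(np.array(X), axis=0)

print(X_batch.shape)                        # (1, 8, 2)
print(np.array_equal(X_batch, X_expanded))  # True
print(X_batch[0, 2])                        # [1 1]: third time step, both input bits
```

More training examples would be stacked along axis 0, giving shape (n, 8, 2).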

answered Nov 03 '22 at 18:11 by sytrus



Keras recurrent layers require 3D input, as the error states: [samples, time steps, features]. Your array has shape (8L, 2L), so Keras treats it as 2D: [samples, features]. To fix it, do something like this:

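import numpy

def reshape_dataset(train):
    train = numpy.asarray(train)  # also accept plain Python lists
    # (samples, features) -> (samples, 1 time step, features)
    return numpy.reshape(train, (train.shape[0], 1, train.shape[1]))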

x = reshape_dataset(your_dataset)

Now x should have shape (8L, 1L, 2L), which is [samples, time steps, features], i.e. 3D.
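Applied to the question's X, this reshape produces eight independent sequences of one time step each. The sketch below repeats the function so it runs on its own (the `asarray` call is there so a plain list works). This layout matches a layer declared with `input_length=1, input_dim=2`; against the (None, 8, 2) input shape in the question's second error it fails the same way, and by default (non-stateful layers) the recurrent state is not carried from one bit pair to the next.

```python
import numpy

def reshape_dataset(train):
    train = numpy.asarray(train)
    # (samples, features) -> (samples, 1 time step, features)
    return numpy.reshape(train, (train.shape[0], 1, train.shape[1]))

X = [[0, 0], [0, 1], [1, 1], [0, 1], [1, 0], [1, 0], [1, 1], [1, 0]]
x = reshape_dataset(X)
print(x.shape)  # (8, 1, 2): 8 samples, each a sequence of 1 time step

# For comparison, the same data as ONE sequence of 8 time steps
print(numpy.reshape(X, (1, 8, 2)).shape)  # (1, 8, 2)
```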

answered Nov 03 '22 at 18:11 by Byte_me
