
Input to `.fit()` should have rank 4

Tags:

python

keras

I am learning data augmentation with Keras, following this link: https://keras.io/preprocessing/image/. The example there uses a CNN, but when I try it with dense layers, as shown below, I get an error at datagen.fit(). The message says the input should have rank 4. How do I resolve this?

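from keras.datasets import cifar10
from keras.layers import Activation, Dense
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import to_categorical

epochs = 10  # placeholder: the question does not say how many epochs were used

# import dataset: X_train has shape (50000, 32, 32, 3), X_test (10000, 32, 32, 3)
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

# change shape from image to vector (this turns the rank-4 arrays into rank-2 arrays)
X_train = X_train.reshape(50000, 32 * 32 * 3)
X_test = X_test.reshape(10000, 32 * 32 * 3)

# preprocess: scale pixel values to [0, 1]
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255.0
X_test /= 255.0

# change labels from numeric to one-hot encoded
Y_train = to_categorical(y_train, 10)
Y_test = to_categorical(y_test, 10)

# fully connected model that expects flat 3072-element vectors
model = Sequential()
model.add(Dense(1024, input_shape=(3072, )))
model.add(Activation('relu'))
model.add(Dense(10))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)

# computes the feature-wise statistics; this is the call that raises the error
datagen.fit(X_train)

# fits the model on batches with real-time data augmentation:
model.fit_generator(datagen.flow(X_train, Y_train, batch_size=32),
                    steps_per_epoch=len(X_train) // 32, epochs=epochs, verbose=1,
                    validation_data=datagen.flow(X_test, Y_test, batch_size=32))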

Error

---------------------------------------------------------------------------
ValueError                         
---> 57 datagen.fit(X_train)


ValueError: Input to `.fit()` should have rank 4. Got array with shape: (50000, 3072)
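The error is easier to see by comparing shapes. The sketch below uses only NumPy and a small random array as a stand-in for CIFAR-10; `featurewise_stats` is a hypothetical helper that approximates the per-channel statistics the featurewise_center and featurewise_std_normalization options need, assuming channels-last data. It is not Keras's own code, but it illustrates why datagen.fit() wants separate sample, height, width and channel axes: a (50000, 3072) array has no channel axis to compute statistics over.

```python
import numpy as np

# Small stand-in for CIFAR-10: 8 random 32x32 images with 3 channels.
# (The real X_train from cifar10.load_data() has shape (50000, 32, 32, 3).)
rng = np.random.default_rng(0)
images = rng.random((8, 32, 32, 3), dtype=np.float32)

# The question's reshape turns each image into a flat 3072-vector: rank 2.
flat = images.reshape(len(images), 32 * 32 * 3)
print(images.ndim, images.shape)  # 4 (8, 32, 32, 3)
print(flat.ndim, flat.shape)      # 2 (8, 3072)


def featurewise_stats(x):
    """Per-channel mean/std over the sample, height and width axes.

    Hypothetical helper approximating what feature-wise centering and
    std normalization need; assumes channels-last (N, H, W, C) input.
    """
    if x.ndim != 4:
        raise ValueError(
            f"expected rank 4 (samples, height, width, channels), got shape {x.shape}")
    mean = x.mean(axis=(0, 1, 2), keepdims=True)  # shape (1, 1, 1, channels)
    std = x.std(axis=(0, 1, 2), keepdims=True)    # shape (1, 1, 1, channels)
    return mean, std


mean, std = featurewise_stats(images)
print(mean.shape)  # (1, 1, 1, 3)
```

Calling `featurewise_stats(flat)` raises a ValueError for the same reason the question's datagen.fit(X_train) does.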
pranav nerurkar asked Sep 21 '25 06:09



1 Answer

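You reshaped your array. ImageDataGenerator requires a rank-4 input array of shape (number of images, height, width, channels), but your reshape produces a rank-2 array of shape (number of images, 3072), hence the error. The fix is to remove the reshape and then add a convolutional layer before the first Dense layer, with a Flatten layer between them so the model still outputs one 10-class prediction per image (just a suggestion).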
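If you would rather keep the Dense-only model, one possible alternative (not part of the answer above) is to give the generator rank-4 images and flatten each augmented batch before it reaches the model. The NumPy sketch below assumes channels-last 32x32x3 images; `restore_images` and `flatten_batches` are hypothetical helper names, and the fake batch iterator stands in for `datagen.flow(...)`, which yields `(x_batch, y_batch)` pairs.

```python
import numpy as np


def restore_images(flat_x, height=32, width=32, channels=3):
    """Undo the question's reshape: (N, H*W*C) -> (N, H, W, C)."""
    return flat_x.reshape(-1, height, width, channels)


def flatten_batches(batches):
    """Wrap an iterator of (x_batch, y_batch) pairs and flatten each image.

    Hypothetical helper: `batches` stands in for datagen.flow(...), which
    yields rank-4 image batches, while the Dense model expects (batch, 3072).
    """
    for x_batch, y_batch in batches:
        yield x_batch.reshape(len(x_batch), -1), y_batch


# Round trip: flattening and restoring gives back the same images.
rng = np.random.default_rng(0)
images = rng.random((4, 32, 32, 3))
flat = images.reshape(4, 32 * 32 * 3)
restored = restore_images(flat)
print(restored.shape, np.array_equal(restored, images))  # (4, 32, 32, 3) True

# Fake augmented batches in place of datagen.flow(X_train, Y_train, batch_size=2).
fake_flow = iter([(images[:2], np.eye(10)[:2]), (images[2:], np.eye(10)[2:4])])
for x_flat, y in flatten_batches(fake_flow):
    print(x_flat.shape, y.shape)  # (2, 3072) (2, 10)
```

In the question's script this would mean calling datagen.fit() on the restored rank-4 X_train and passing the flow wrapped in `flatten_batches` to fit_generator, leaving `input_shape=(3072, )` unchanged. This is an untested sketch of the idea, not a drop-in fix.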

Ashmit Bhattarai answered Sep 22 '25 21:09
