 

ValueError: `decode_predictions` expects a batch of predictions (i.e. a 2D array of shape (samples, 1000)). Found array with shape: (1, 7)

I am using VGG16 with Keras for transfer learning (my new model has 7 classes), so I want to use the built-in decode_predictions method to output my model's predictions. However, when I use the following code:

preds = model.predict(img)

decode_predictions(preds, top=3)[0]

I receive the following error message:

ValueError: decode_predictions expects a batch of predictions (i.e. a 2D array of shape (samples, 1000)). Found array with shape: (1, 7)

Now I wonder why it expects 1000 when I only have 7 classes in my retrained model.

A similar question I found here on Stack Overflow (Keras: ValueError: decode_predictions expects a batch of predictions) suggests setting include_top=True in the model definition to solve this problem:

model = VGG16(weights='imagenet', include_top=True)

I have tried this; however, it still does not work and gives me the same error as before. Any hint or suggestion on how to solve this issue is highly appreciated.

asked Mar 13 '18 at 14:03 by AaronDT


1 Answer

I suspect you are using a pre-trained model, for instance ResNet50, and importing decode_predictions like this:

from keras.applications.resnet50 import decode_predictions

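decode_predictions transforms an array of probabilities with shape (num_samples, 1000) into the class names of the original 1000 ImageNet classes. That is why it expects 1000 columns: it is tied to the ImageNet classifier head, not to your model.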
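To see why a 7-class output is rejected, here is a minimal sketch of the kind of shape check and top-k lookup an ImageNet decoder performs. It is not the Keras source; the function name decode_like_imagenet and the placeholder class names are hypothetical, and only the error message is copied from the question.

```python
import numpy as np

# Hypothetical placeholders for the 1000 ImageNet class names
# (the real decoder looks up the actual ImageNet labels).
IMAGENET_CLASSES = [f"imagenet_class_{i}" for i in range(1000)]


def decode_like_imagenet(preds, top=5):
    """Sketch: validate a (samples, 1000) batch, then return top-k (name, score) pairs."""
    preds = np.asarray(preds)
    # An ImageNet decoder can only map 1000 columns to 1000 class names.
    if preds.ndim != 2 or preds.shape[1] != 1000:
        raise ValueError(
            "`decode_predictions` expects a batch of predictions "
            "(i.e. a 2D array of shape (samples, 1000)). "
            f"Found array with shape: {preds.shape}"
        )
    results = []
    for row in preds:
        # Indices of the `top` highest scores, best first.
        top_idx = row.argsort()[-top:][::-1]
        results.append([(IMAGENET_CLASSES[i], float(row[i])) for i in top_idx])
    return results


# A 7-class softmax output, like the one in the question, fails the check.
try:
    decode_like_imagenet(np.full((1, 7), 1 / 7), top=3)
except ValueError as err:
    print(err)
```

Setting include_top=True does not change this: the decoder only understands the original 1000-way ImageNet head, so once you replace the head with your own 7-class layer, its output can no longer be decoded this way.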

If you want to use transfer learning to classify between 7 different classes, you need to do something like this:

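from keras.applications.resnet50 import ResNet50
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model

# load the convolutional base without the 1000-class ImageNet classifier
base_model = ResNet50(weights='imagenet', include_top=False)

# add a global spatial average pooling layer
x = base_model.output
x = GlobalAveragePooling2D()(x)
# add a fully-connected layer
x = Dense(1024, activation='relu')(x)
# and a softmax output layer -- let's say we have 7 classes
predictions = Dense(7, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)
# ... compile the model and fit it on your 7-class data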

After fitting the model and calculating predictions, you have to map each output index to its class name yourself instead of using the imported decode_predictions.
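A minimal sketch of that manual mapping, using only NumPy. The class names and the helpers decode_custom_predictions and names_from_class_indices are hypothetical; replace CLASS_NAMES with your own labels in the same order as your model's output units. If you trained with ImageDataGenerator.flow_from_directory, the returned iterator's class_indices attribute is a {class_name: index} dict, which the second helper turns into such an ordered list.

```python
import numpy as np

# Hypothetical labels for the 7-class model, in output-unit order.
CLASS_NAMES = ["cat", "dog", "horse", "sheep", "cow", "bird", "fish"]


def decode_custom_predictions(preds, class_names, top=3):
    """Map each row of a (samples, num_classes) array to its top-k (name, score) pairs."""
    preds = np.asarray(preds)
    if preds.ndim != 2 or preds.shape[1] != len(class_names):
        raise ValueError(
            f"Expected shape (samples, {len(class_names)}), got {preds.shape}"
        )
    decoded = []
    for row in preds:
        top_idx = np.argsort(row)[::-1][:top]  # highest probability first
        decoded.append([(class_names[i], float(row[i])) for i in top_idx])
    return decoded


def names_from_class_indices(class_indices):
    """Turn a {name: index} dict (e.g. generator.class_indices) into a list ordered by index."""
    return [name for name, _ in sorted(class_indices.items(), key=lambda kv: kv[1])]


# A fake softmax output standing in for model.predict(img):
preds = np.array([[0.02, 0.60, 0.15, 0.03, 0.10, 0.05, 0.05]])
print(decode_custom_predictions(preds, CLASS_NAMES, top=3)[0])
# → [('dog', 0.6), ('horse', 0.15), ('cow', 0.1)]
```

The output mirrors the list-of-lists shape that decode_predictions returns, so code written around decode_predictions(preds, top=3)[0] only needs the extra class_names argument.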

answered Sep 17 '22 at 23:09 by Ioannis Nasios