
Keras: What is the output of predict_generator?

The Keras documentation says it returns "A Numpy array of predictions." Using it on 496 image examples with 4 classes, I get a 4-dimensional array (496, 4, 4, 512). What are the other 2 dimensions? Eventually, I would like to have an array of X (examples) and an array of Y (labels).

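import numpy as np
from keras import applications
from keras.preprocessing.image import ImageDataGenerator

img_width, img_height = 150, 150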
top_model_weights_path = 'bottleneck_fc_model.h5'
train_data_dir = 'data/train'
validation_data_dir = 'data/validation'
nb_train_samples = 496
nb_validation_samples = 213
epochs = 50
batch_size = 16
number_of_classes = 3
datagen = ImageDataGenerator(rescale=1. / 255)

# build the VGG16 network (exclude last layer)
model = applications.VGG16(include_top=False, weights='imagenet')

# generate training data from image files
train_generator = datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False)

# predict bottleneck features on training data
bottleneck_features_train = model.predict_generator(
    train_generator, nb_train_samples // batch_size)
print(bottleneck_features_train.shape)

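# save the features to disk, then load them back
np.save('bottleneck_features_train.npy', bottleneck_features_train)
train_data = np.load('bottleneck_features_train.npy')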
print(train_data.shape)
Shahar Barak asked Oct 18 '22 14:10



1 Answer

What you're doing is extracting the bottleneck features from the images you're feeding to the model. The shape (496, 4, 4, 512) you're getting is (n_samples, feature_height, feature_width, feature_channels). You removed the dense layers of the model by passing

include_top=False

To explain graphically, you passed the samples through this model

[image: VGG16 Model]

without its last 4 layers (the flatten layer and the three dense layers). The height and width are 4x4 rather than the 7x7 of standard VGG16 because your starting image is 150x150 and not 224x224.
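To see where the 4x4 comes from, here is a small sketch of the arithmetic. It assumes the standard VGG16 layout (convolutions with 'same' padding that keep the size, and five 2x2 max-pooling layers with stride 2); the helper name vgg16_feature_map_size is made up for illustration and is not part of Keras.

```python
def vgg16_feature_map_size(input_size, num_pool_layers=5):
    """Spatial size of the VGG16 output when include_top=False.

    Assumes the convolutions keep the size ('same' padding) and that each
    of the five 2x2, stride-2 max-pooling layers halves it, rounding down.
    """
    size = input_size
    for _ in range(num_pool_layers):
        size //= 2
    return size


print(vgg16_feature_map_size(150))  # 4  (150 -> 75 -> 37 -> 18 -> 9 -> 4)
print(vgg16_feature_map_size(224))  # 7  (224 -> 112 -> 56 -> 28 -> 14 -> 7)
```

The 512 in the last dimension is the number of filters in the final convolutional block, so it does not depend on the image size.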

What you obtain is not the prediction of the classes but a synthetic representation of the important features of the images.
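If the goal is an array of X (examples) and an array of Y (labels) built from these features, one possible approach is sketched below. It uses NumPy stand-ins instead of a real model: the zero array takes the place of the predict_generator output, and the label array takes the place of train_generator.classes, which lines up with the features only because the generator was created with shuffle=False. The per-class counts are invented for the example.

```python
import numpy as np

number_of_classes = 3

# Stand-in for the array returned by model.predict_generator(...) in the question.
bottleneck_features_train = np.zeros((496, 4, 4, 512), dtype=np.float32)

# Stand-in for train_generator.classes. The per-class counts are made up
# for illustration; the real counts depend on the folders in data/train.
labels = np.repeat(np.arange(number_of_classes), [166, 165, 165])

# X: one flat feature vector per image, 4 * 4 * 512 = 8192 values each.
X = bottleneck_features_train.reshape(len(bottleneck_features_train), -1)

# Y: one-hot encoded labels, one row per image.
Y = np.eye(number_of_classes, dtype=np.float32)[labels]

print(X.shape)  # (496, 8192)
print(Y.shape)  # (496, 3)
```

Note that this only works row for row when the number of steps covers every sample, as it does here (496 // 16 = 31 full batches of 16).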

To obtain what you seem to need you can modify the code like this

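from keras.layers import Dense, Flatten
from keras.models import Model
base_model = applications.VGG16(include_top=False, weights='imagenet', input_shape=(img_height, img_width, 3))  # fixed input size so Flatten knows its output size
for layer in base_model.layers:
    layer.trainable = False  # freeze the pretrained convolutional layers
x = Flatten()(base_model.output)
x = Dense(512, activation='relu')(x)  # 512 is a parameter you can tweak; the higher, the more complex the model
predictions = Dense(number_of_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])  # optimizer is just an example choice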

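Once the model is compiled, call model.fit(X, Y) on the samples you're using to train it, passing the 496 sample images as X and the ground truth labels you prepared as Y (one-hot encoded, to match the softmax output). If you keep the generator from the question, model.fit_generator(train_generator, ...) does the same job, since class_mode='categorical' makes it yield the labels along with the images.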

After training, you can use model.predict to get the class probabilities for each image; the class you need is the one with the highest probability.
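As a rough sketch of that last step, the snippet below turns softmax outputs into class indices and names with np.argmax. The probabilities and the class names are invented for the example; in real code the array would come from model.predict and the mapping from train_generator.class_indices.

```python
import numpy as np

# Hypothetical mapping, in the same format as train_generator.class_indices.
class_indices = {'class_a': 0, 'class_b': 1, 'class_c': 2}
index_to_name = {index: name for name, index in class_indices.items()}

# Hypothetical model.predict output for 4 images and 3 classes:
# one row per image, one softmax probability per class.
probabilities = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],
])

# The predicted class is the column with the highest probability in each row.
predicted_indices = np.argmax(probabilities, axis=1)
predicted_names = [index_to_name[int(i)] for i in predicted_indices]

print(predicted_indices)  # [0 1 2 0]
print(predicted_names)    # ['class_a', 'class_b', 'class_c', 'class_a']
```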

Mike answered Oct 21 '22 04:10
