 

Trained keras model much slower making its predictions than in training

Tags:

python

keras

I trained a Keras model overnight and got 75% accuracy, which I am happy with for now. The dataset has 60,000 samples, each a sequence of length 700 over a vocabulary of 30. Each epoch takes about 10 minutes on my GPU, so that's 60,000 samples / 600 seconds, or roughly 100 samples per second, and that includes backpropagation. I saved the model as an HDF5 file and loaded it again.

<code>#Model:
model = Sequential()
model.add(LSTM(128, input_shape=(X.shape[1], X.shape[2]), return_sequences=True))
model.add(Dropout(0.25))
model.add(LSTM(64))
model.add(Dropout(0.25))
model.add(Dense(y.shape[1], activation='softmax'))
</code>
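The save-and-reload step mentioned above can be sketched as follows. This is a minimal illustration, not the asker's actual code: the tiny Dense model and the filename "model.h5" are stand-ins, since the real LSTM model and its filename aren't shown.

```python
import numpy as np
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Input, Dense

# Hypothetical stand-in model (the real one is the LSTM above)
model = Sequential([Input(shape=(8,)), Dense(4, activation="softmax")])
model.compile(loss="categorical_crossentropy", optimizer="adam")

# Save to HDF5 (architecture + weights + optimizer state), then restore
model.save("model.h5")
restored = load_model("model.h5")

# The restored model produces identical predictions
x = np.random.rand(2, 8).astype("float32")
assert np.allclose(model.predict(x), restored.predict(x))
```

Restoring via `load_model` rebuilds the full model, so no re-training or re-compilation is needed before calling `predict`.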

When I then make predictions, each one takes about 1 second, which is 100 times slower than training. The predictions themselves are good; I've checked small batches and I can use them. The problem is that I need many hundreds of thousands of them. 10 ms per prediction would work; 1 second won't.

Can anyone suggest ways of speeding up Keras predictions?

photox asked Dec 26 '16



1 Answer

I think it's because Keras's `predict` defaults to a batch size of 32. Especially on a GPU, such small batches leave the device underutilized and destroy throughput. If you simply pass a larger batch, e.g. `predict(X_test, batch_size=128)`, you'll get significantly faster performance.
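As a sketch of the fix, assuming a compiled Keras model and a NumPy array of test inputs (the small Dense model and the array shapes here are illustrative stand-ins, not the asker's data):

```python
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense

# Hypothetical stand-in model and inputs for illustration
model = Sequential([Input(shape=(5,)), Dense(3, activation="softmax")])
model.compile(loss="categorical_crossentropy", optimizer="adam")
X_test = np.random.rand(1000, 5).astype("float32")

# Default batch_size is 32; a larger batch keeps the GPU busy
preds = model.predict(X_test, batch_size=128)
print(preds.shape)  # one softmax row per input sample
```

The key point is to call `predict` once on the whole array with a large `batch_size`, rather than looping and predicting one sample at a time, which pays per-call overhead and starves the GPU.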

mf908 answered Sep 21 '22