
What do model.predict() and model.fit() do?

I'm going through this reinforcement learning tutorial and it's been really great so far, but could someone please explain what

newQ = model.predict(new_state.reshape(1,64), batch_size=1)

and

model.fit(X_train, y_train, batch_size=batchSize, nb_epoch=1, verbose=1)

mean?

That is, what do the arguments batch_size, nb_epoch and verbose do? I know neural networks, so an explanation in those terms would be helpful.

You could also send me a link where the documentation of these functions can be found.

Soham asked Jun 22 '16 15:06


1 Answer

First of all, it surprises me that you could not find the documentation, but I guess you just had bad luck while searching.

The documentation for model.fit states:

fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[], validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None)

  • batch_size: integer. Number of samples per gradient update.
  • nb_epoch: integer, the number of times to iterate over the training data arrays.
  • verbose: 0, 1, or 2. Verbosity mode. 0 = silent, 1 = verbose, 2 = one log line per epoch.
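To make these three arguments concrete, here is a minimal plain-Python sketch of mini-batch gradient descent on a one-variable linear model. It is not Keras's implementation; it only mirrors what the arguments mean: batch_size samples contribute to each gradient update, the outer loop runs once per epoch (nb_epoch, which Keras 2 renamed to epochs), and verbose controls how much is printed. The function name fit, the linear model and the lr parameter are assumptions made for illustration (in Keras the learning rate belongs to the optimizer, not to fit).

```python
import random


def fit(x, y, batch_size=32, epochs=10, verbose=1, shuffle=True, lr=0.1):
    """Train the hypothetical model y ~ w*x + b with mini-batch gradient descent on MSE.

    Illustrates the meaning of batch_size / nb_epoch / verbose; NOT Keras's code.
    `lr` is a hypothetical parameter added for this sketch.
    """
    w, b = 0.0, 0.0
    n = len(x)
    order = list(range(n))
    updates = 0
    for epoch in range(epochs):                  # nb_epoch: full passes over the data
        if shuffle:
            random.shuffle(order)
        sq_err_sum = 0.0
        for start in range(0, n, batch_size):    # batch_size: samples per gradient update
            batch = order[start:start + batch_size]
            grad_w = grad_b = 0.0
            for i in batch:
                err = (w * x[i] + b) - y[i]
                grad_w += 2 * err * x[i] / len(batch)
                grad_b += 2 * err / len(batch)
                sq_err_sum += err * err          # loss measured before this batch's update
            w -= lr * grad_w                     # exactly one weight update per batch
            b -= lr * grad_b
            updates += 1
            if verbose == 1:                     # 1: report after every batch (Keras shows a progress bar)
                print(f"epoch {epoch + 1}, batch {start // batch_size + 1}")
        if verbose == 2:                         # 2: one line per epoch
            print(f"epoch {epoch + 1}/{epochs} - loss: {sq_err_sum / n:.4f}")
    return w, b, updates


# Usage: 10 noise-free samples of y = 2x + 1; verbose=0 keeps training silent.
xs = [i / 10 for i in range(10)]
ys = [2 * v + 1 for v in xs]
_, _, n_updates = fit(xs, ys, batch_size=4, epochs=3, verbose=0)
print(n_updates)  # 9
```

With 10 samples and batch_size=4, each epoch consists of batches of 4, 4 and 2 samples, i.e. ceil(10 / 4) = 3 weight updates, so 3 epochs perform 9 updates. Smaller batches give more (noisier) updates per epoch; larger batches give fewer, smoother ones.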

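In model.predict, batch_size is the number of samples pushed through the network together in one prediction step. A single call to model.predict still returns a prediction for every sample you pass in; the input is simply split into batches of batch_size samples. Larger batches help on devices that can process large matrices quickly (such as GPUs). In the tutorial's call, new_state.reshape(1,64) is a single sample, so batch_size=1 handles it in one step.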
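A similar sketch for prediction, again in plain Python with a hypothetical linear model (w and b are made-up parameters, and this predict is not the Keras function): the input is cut into chunks of at most batch_size samples, each chunk goes through one forward pass, and the outputs are concatenated, so every sample gets a prediction whichever batch_size is chosen.

```python
def predict(w, b, x, batch_size=32):
    """Apply the hypothetical model y = w*x + b to every sample in x,
    processing at most batch_size samples per forward pass."""
    outputs = []
    passes = 0
    for start in range(0, len(x), batch_size):
        batch = x[start:start + batch_size]        # one chunk of <= batch_size samples
        outputs.extend(w * v + b for v in batch)   # stands in for one batched forward pass
        passes += 1
    return outputs, passes


# Five samples with batch_size=2 -> chunks of 2, 2 and 1 -> 3 passes, 5 predictions.
preds, passes = predict(2.0, 1.0, [0.0, 1.0, 2.0, 3.0, 4.0], batch_size=2)
print(preds, passes)  # [1.0, 3.0, 5.0, 7.0, 9.0] 3
```

The batch size only changes how the work is chunked, not the result; on a GPU a larger batch means fewer, larger matrix operations.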

nemo answered Oct 16 '22 10:10