
How to use Model.fit which supports generators (after fit_generator deprecation)


I got this deprecation warning while using Model.fit_generator in TensorFlow:

WARNING:tensorflow: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: Please use Model.fit, which supports generators. 

How can I use Model.fit instead of Model.fit_generator?

Asked by Behnam on Dec 17 '19 at 18:12



2 Answers

Model.fit_generator is deprecated starting from TensorFlow 2.1.0, which was at rc1 at the time of writing. You can find the documentation for tf-2.1.0-rc1 here: https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/Model#fit

As you can see, the first argument of Model.fit (x) accepts a generator, so just pass your generator to it directly.
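For a hand-written Python generator, a minimal sketch might look like the following. The helper names (steps_for, batch_generator, train_demo), the random data, and the tiny model are all made up for illustration and are not from the question. The one thing to watch is that a plain generator has no length, so I pass steps_per_epoch to tell fit where an epoch ends.

```python
import math


def steps_for(num_samples, batch_size):
    """Number of batches needed to cover num_samples once (last batch may be smaller)."""
    return math.ceil(num_samples / batch_size)


def batch_generator(features, labels, batch_size):
    """Yield (inputs, targets) batches forever, looping over the data repeatedly."""
    num_samples = len(features)
    while True:
        for start in range(0, num_samples, batch_size):
            end = start + batch_size
            yield features[start:end], labels[start:end]


def train_demo():
    """Train a toy model from the generator above (hypothetical data and model)."""
    # Imported here so the helpers above can be used without TensorFlow installed.
    import numpy as np
    import tensorflow as tf

    # Random stand-in data: 100 samples, 4 features, binary labels.
    features = np.random.rand(100, 4).astype("float32")
    labels = np.random.randint(0, 2, size=(100, 1)).astype("float32")

    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy")

    # The generator goes in as the first argument, where fit_generator took it before.
    # It never stops on its own, so steps_per_epoch marks the end of each epoch.
    return model.fit(
        batch_generator(features, labels, batch_size=32),
        steps_per_epoch=steps_for(len(features), 32),
        epochs=2,
    )
```

Calling train_demo() should run two short epochs and return the usual History object. If your generator is a keras.utils.Sequence or a Keras directory iterator, it already knows its length and steps_per_epoch can usually be left out.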

Answered by Wathek LOUED on Oct 01 '22 at 07:10


As mentioned in the documentation for Model.fit:

x: Input data. It could be

  • A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs).
  • A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs).
  • A dict mapping input names to the corresponding array/tensors, if the model has named inputs.
  • A tf.data dataset. Should return a tuple of either (inputs, targets) or (inputs, targets, sample_weights)
  • A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample weights). A more detailed description of unpacking behavior for iterator types (Dataset, generator, Sequence) is given below.

So you can pass the generator to Model.fit the same way you passed it to Model.fit_generator:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# train_dir and validation_dir are the paths to your own image folders.
data_gen_train = ImageDataGenerator(rescale=1/255.)
data_gen_valid = ImageDataGenerator(rescale=1/255.)

train_generator = data_gen_train.flow_from_directory(
    train_dir, target_size=(128, 128), batch_size=128, class_mode="binary")
valid_generator = data_gen_valid.flow_from_directory(
    validation_dir, target_size=(128, 128), batch_size=128, class_mode="binary")

# Pass the generators straight to fit instead of fit_generator.
model.fit(train_generator, epochs=2, validation_data=valid_generator)
Answered by Anurag Mishra on Oct 01 '22 at 05:10