Hi, I don't understand the Keras fit_generator docs.
I hope my confusion is rational.
There is a batch_size and also the concept of training in batches. Using model.fit(), I specify a batch_size of 128.
To me this means that my dataset will be fed in 128 samples at a time, which should greatly reduce memory use. It should allow a 100 million sample dataset to be trained as long as I have the time to wait. After all, Keras is only "working with" 128 samples at a time. Right?
But I highly suspect that specifying the batch_size alone doesn't do what I want at all. Tons of memory is still being used. For my goals I need to train in batches of 128 examples each.
So I am guessing this is what fit_generator does. I really want to ask: why doesn't batch_size actually work as its name suggests?
More importantly, if fit_generator is needed, where do I specify the batch_size? The docs say to loop indefinitely.
A generator loops over every row once. How do I loop over 128 samples at a time, remember where I last stopped, and pick up from there the next time Keras asks for a batch (that would be row 129 after the first batch is done)?
The fit_generator() function first accepts a batch of the dataset, then performs backpropagation on it, and then updates the weights in our model. The process is repeated for the number of epochs specified (10 in our case).
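To make that loop concrete, here is a rough sketch in plain Python of what "take a batch, compute gradients, update weights, repeat for N epochs" looks like. This is not how Keras is implemented; `toy_batches`, `train`, and the one-weight model `y = w * x` are made up purely for illustration.

```python
def toy_batches(xs, ys, batch_size):
    """Hypothetical generator: yields (batch_x, batch_y) forever, in order."""
    start = 0
    while True:
        end = min(start + batch_size, len(xs))
        yield xs[start:end], ys[start:end]
        start = 0 if end >= len(xs) else end  # wrap around for the next epoch

def train(gen, steps_per_epoch, epochs, lr=0.5):
    """Fit y = w * x by gradient descent, one batch at a time."""
    w = 0.0
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            batch_x, batch_y = next(gen)  # only this batch is needed right now
            # Gradient of the mean squared error with respect to w
            grad = sum(2 * (w * x - y) * x for x, y in zip(batch_x, batch_y)) / len(batch_x)
            w -= lr * grad  # weight update after every batch
    return w

xs = [i / 10 for i in range(10)]   # 0.0, 0.1, ..., 0.9
ys = [2 * x for x in xs]           # true weight is 2
w = train(toy_batches(xs, ys, batch_size=4), steps_per_epoch=3, epochs=10)
print(w)
```

The point is that the training loop only ever asks the generator for the next batch, so the generator decides how much data is in memory at once.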
fit_generator() method: The model is trained on data generated batch-by-batch by a Python generator. Returns a `History` object.
With fit(), you pass your whole dataset at once, so use it when you can load all of the data into memory (small dataset). With fit_generator(), you don't pass x and y directly; instead they come from a generator.
If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of a dataset, generators, or keras.
You will need to handle the batch size somehow inside the generator. Here is an example to generate random batches:
import numpy as np

# Toy dataset: 100 samples, one feature column plus a label column (parity)
data = np.arange(100)
data_lab = data % 2
wholeData = np.array([data, data_lab]).T  # shape (100, 2)

def data_generator(all_data, batch_size=20):
    while True:
        # Draw batch_size random row indices (sampling with replacement)
        idx = np.random.randint(len(all_data), size=batch_size)
        # Assuming the last column contains labels
        batch_x = all_data[idx, :-1]
        batch_y = all_data[idx, -1]
        # Return a tuple of (Xs, Ys) to feed the model
        yield (batch_x, batch_y)

# The generator never terminates, so pull a single batch instead of iterating over it
print(next(data_generator(wholeData)))
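If you would rather walk through the rows in order than sample randomly, the generator itself can remember where it stopped: a Python generator is suspended at `yield` and resumes from the same spot on the next call, so a local variable is enough. A minimal sketch, assuming the data supports slicing (lists here, but NumPy arrays slice the same way); the name `sequential_batches` is just illustrative.

```python
import math

def sequential_batches(features, labels, batch_size=128):
    """Yield (batch_x, batch_y) in row order, forever, wrapping at the end."""
    n = len(features)
    start = 0  # survives between next() calls because the generator is only paused
    while True:
        end = min(start + batch_size, n)
        yield features[start:end], labels[start:end]
        start = 0 if end >= n else end  # begin a new pass once the data is exhausted

features = list(range(300))
labels = [v % 2 for v in features]
batch_size = 128

# Number of batches that make up one full pass over the data
steps_per_epoch = math.ceil(len(features) / batch_size)

gen = sequential_batches(features, labels, batch_size)
first_x, first_y = next(gen)    # rows 0..127
second_x, second_y = next(gen)  # rows 128..255
third_x, third_y = next(gen)    # rows 256..299 (short final batch)
fourth_x, fourth_y = next(gen)  # wrapped around: rows 0..127 again
```

So the batch size lives inside the generator, and what you tell Keras is how many batches make up an epoch, along the lines of `model.fit_generator(gen, steps_per_epoch=steps_per_epoch, epochs=10)`. In recent TensorFlow releases `fit_generator` is deprecated and `model.fit` accepts the generator directly.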