I built a simple generator that yields a tuple (inputs, targets) with only single items in the inputs and targets lists. Basically, it is crawling the data set, one sample item at a time.
I pass this generator into:
    model.fit_generator(my_generator(),
                        nb_epoch=10,
                        samples_per_epoch=1,
                        max_q_size=1  # defaults to 10
                        )
I get that:
- nb_epoch is the number of times the training batch will be run
- samples_per_epoch is the number of samples trained with per epoch

But what is max_q_size for, and why would it default to 10? I thought the purpose of using a generator was to batch data sets into reasonable chunks, so why the additional queue?
Our .fit_generator() function first accepts a batch of the dataset, then performs backpropagation on it, and then updates the weights in our model. The process is repeated for the number of epochs specified (10 in our case).
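To make the roles of nb_epoch and samples_per_epoch concrete, here is a minimal pure-Python sketch of such a loop. It is not the Keras implementation: `single_sample_generator`, `fit_generator_sketch`, the toy data, and the learning rate `lr` are all hypothetical names and values chosen for illustration, and the "model" is a single weight fitted to y = w * x.

```python
def single_sample_generator():
    """Hypothetical data source: yields one (inputs, targets) pair at a time, forever."""
    data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
    while True:
        for x, y in data:
            yield [x], [y]


def fit_generator_sketch(generator, samples_per_epoch, nb_epoch, lr=0.05):
    """Toy stand-in for a training loop: fits y = w * x by gradient descent."""
    w = 0.0
    updates = 0
    for _ in range(nb_epoch):
        samples_seen = 0
        # Keep pulling batches until this epoch has seen enough samples.
        while samples_seen < samples_per_epoch:
            inputs, targets = next(generator)
            # Gradient of the mean squared error with respect to w for this batch.
            grad = sum(2 * (w * x - y) * x
                       for x, y in zip(inputs, targets)) / len(inputs)
            w -= lr * grad  # weight update after each batch
            updates += 1
            samples_seen += len(inputs)
    return w, updates
```

With one sample per batch, the number of weight updates is simply nb_epoch * samples_per_epoch, so the call from the question (samples_per_epoch=1, nb_epoch=10) would perform ten updates in this sketch.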
Update July 2021: For TensorFlow 2.2+ users, just use the .fit method for your projects. The .fit_generator method will be deprecated in future releases of TensorFlow.
The Keras methods fit_generator, evaluate_generator, and predict_generator have an argument called workers. By setting workers to 2, 4, 8 or multiprocessing.cpu_count() instead of the default 1, Keras will spawn threads (or processes with the use_multiprocessing argument) when ingesting data batches.
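One thing to keep in mind when several worker threads share a single generator: a plain Python generator raises ValueError ("generator already executing") if two threads call next() on it at the same moment. A common workaround is to guard the iterator with a lock. The sketch below uses only the standard library; `ThreadSafeIterator` and `consume_with_workers` are hypothetical names for illustration and are not part of Keras.

```python
import threading


class ThreadSafeIterator:
    """Wraps an iterator so that only one thread at a time can advance it."""

    def __init__(self, iterator):
        self._iterator = iterator
        self._lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        # The lock serialises access, so the wrapped generator is never re-entered.
        with self._lock:
            return next(self._iterator)


def consume_with_workers(iterator, n_workers):
    """Drains `iterator` from several threads and returns everything collected."""
    collected = []
    collected_lock = threading.Lock()

    def worker():
        for item in iterator:
            with collected_lock:
                collected.append(item)

    threads = [threading.Thread(target=worker) for _ in range(n_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return collected
```

Each item is handed to exactly one worker, but the order in which workers receive items is not deterministic, so a training routine fed this way should not rely on batch order.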
This simply defines the maximum size of the internal training queue, which is used to "precache" your samples from the generator. It is used when the queue is built:
    def generator_queue(generator, max_q_size=10,
                        wait_time=0.05, nb_worker=1):
        '''Builds a threading queue out of a data generator.
        Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
        '''
        q = queue.Queue()
        _stop = threading.Event()

        def data_generator_task():
            while not _stop.is_set():
                try:
                    if q.qsize() < max_q_size:
                        try:
                            generator_output = next(generator)
                        except ValueError:
                            continue
                        q.put(generator_output)
                    else:
                        time.sleep(wait_time)
                except Exception:
                    _stop.set()
                    raise

        generator_threads = [threading.Thread(target=data_generator_task)
                             for _ in range(nb_worker)]

        for thread in generator_threads:
            thread.daemon = True
            thread.start()

        return q, _stop
In other words, you have a thread filling the queue up to the given maximum capacity directly from your generator, while (for example) the training routine consumes its elements (and sometimes waits for new ones to arrive):
    while samples_seen < samples_per_epoch:
        generator_output = None
        while not _stop.is_set():
            if not data_gen_queue.empty():
                generator_output = data_gen_queue.get()
                break
            else:
                time.sleep(wait_time)
And why a default of 10? No particular reason. Like most defaults, it simply makes sense, but you could use different values too.
A construction like this suggests that the authors had expensive data generators in mind, which might take time to execute. For example, consider downloading data over a network in a generator call: it then makes sense to precache the next few batches and download further ones in parallel, for the sake of efficiency and to be robust to network errors, etc.
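The producer/consumer idea can be tried out in isolation with a small standard-library sketch. This is not the Keras code: instead of polling q.qsize() and sleeping, it relies on a bounded queue.Queue(maxsize=...) whose put() and get() block, and the names `slow_batches` and `prefetch` are hypothetical. The sleep in `slow_batches` merely stands in for a download or other expensive step.

```python
import queue
import threading
import time


def slow_batches(n, delay=0.01):
    """Hypothetical expensive generator: each batch takes `delay` seconds."""
    for i in range(n):
        time.sleep(delay)  # stands in for a download or heavy preprocessing
        yield i


def prefetch(generator, max_q_size=10):
    """Yields items from `generator`, read ahead by a background thread.

    At most `max_q_size` items wait in the queue at any time.
    """
    q = queue.Queue(maxsize=max_q_size)
    done = object()  # sentinel marking the end of the generator

    def producer():
        try:
            for item in generator:
                q.put(item)  # blocks while the queue is full
        finally:
            q.put(done)  # always let the consumer know we have stopped

    threading.Thread(target=producer, daemon=True).start()

    while True:
        item = q.get()  # blocks while the queue is empty
        if item is done:
            return
        yield item
```

Iterating over prefetch(slow_batches(5), max_q_size=2) returns the same batches in the same order as the bare generator, but the next batch is already being produced while the consumer is busy with the current one. A larger max_q_size lets the producer run further ahead at the cost of holding more batches in memory.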