Even as of Keras 1.2.2 (referencing the merge that added it), multiprocessing support is included, but model.fit_generator() is still about 4-5x slower than model.fit() because of disk-reading speed limitations. How can this be sped up, say through additional multiprocessing?
You may want to check out the workers and max_queue_size parameters of fit_generator() in the documentation. Essentially, more workers create more threads for loading data into the queue that feeds your network. Filling the queue can cause memory problems, though, so you might want to decrease max_queue_size to avoid this.
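A minimal sketch of how those parameters fit together, assuming the Keras 2.x parameter names used above (in Keras 1.x they were nb_worker / max_q_size / pickle_safe); the generator and network here are hypothetical stand-ins for your own disk-reading code:

```python
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

def batch_generator(batch_size=32):
    """Hypothetical generator; replace the random arrays with your
    actual disk-reading logic."""
    while True:
        x = np.random.rand(batch_size, 100)           # stand-in for data read from disk
        y = np.random.randint(0, 2, (batch_size, 1))  # stand-in labels
        yield x, y

model = Sequential([Dense(64, activation='relu', input_shape=(100,)),
                    Dense(1, activation='sigmoid')])
model.compile(optimizer='adam', loss='binary_crossentropy')

# More workers -> more threads/processes filling the queue that feeds the model;
# a smaller max_queue_size caps how much prefetched data sits in memory.
model.fit_generator(batch_generator(),
                    steps_per_epoch=100,
                    epochs=5,
                    workers=4,
                    max_queue_size=10,
                    use_multiprocessing=True)
```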
I had a similar problem: I was using a pandas-based generator, and I switched to dask to load the data into memory instead. So, depending on your data size, if possible, load the data into memory and use the fit function.
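A minimal sketch of that approach, assuming the data fits in memory once loaded; the file pattern, the 'label' column, and the network shape are hypothetical:

```python
import dask.dataframe as dd
from keras.models import Sequential
from keras.layers import Dense

# Read the CSV shards in parallel with dask, then materialise a pandas DataFrame.
df = dd.read_csv('train_*.csv').compute()

x = df.drop(columns=['label']).values
y = df['label'].values

model = Sequential([Dense(64, activation='relu', input_shape=(x.shape[1],)),
                    Dense(1, activation='sigmoid')])
model.compile(optimizer='adam', loss='binary_crossentropy')

# With the data fully in memory, fit() avoids the per-batch disk reads entirely.
model.fit(x, y, batch_size=32, epochs=5)
```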