
keras model.fit_generator() several times slower than model.fit()

As of Keras 1.2.2 (see the referenced merge), multiprocessing is included, but model.fit_generator() is still about 4-5x slower than model.fit() because of disk read speed limitations. How can this be sped up, for example through additional multiprocessing?

mikal94305 asked Mar 07 '17 06:03



2 Answers

You may want to check out the workers and max_queue_size parameters of fit_generator() in the documentation. Essentially, a higher workers value starts more threads that load data into the queue feeding your network. A large, full queue can cause memory problems, though, so you may want to lower max_queue_size if that happens.
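To illustrate the idea behind workers and max_queue_size, here is a minimal standard-library sketch of several threads filling a bounded queue. It is not Keras's implementation; load_batch and prefetch are hypothetical names, and the sleep only simulates a slow disk read.

```python
import queue
import threading
import time


def load_batch(index):
    """Hypothetical stand-in for a slow disk read that returns one batch."""
    time.sleep(0.01)  # simulated I/O latency
    return [index] * 4


def prefetch(num_batches, workers=4, max_queue_size=10):
    """Yield batches loaded by `workers` threads through a bounded queue."""
    todo = queue.Queue()
    for i in range(num_batches):
        todo.put(i)
    # The bound caps how many loaded batches sit in memory at once.
    ready = queue.Queue(maxsize=max_queue_size)

    def worker():
        while True:
            try:
                i = todo.get_nowait()
            except queue.Empty:
                return  # no work left
            ready.put(load_batch(i))  # blocks while the queue is full

    threads = [threading.Thread(target=worker, daemon=True) for _ in range(workers)]
    for t in threads:
        t.start()
    for _ in range(num_batches):
        yield ready.get()


batches = list(prefetch(8, workers=4, max_queue_size=2))
```

With more than one worker, batches may arrive out of order in this sketch. A smaller max_queue_size trades some throughput for a lower memory ceiling.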

Mach_Zero answered Oct 04 '22 09:10



I had a similar problem. I had been using pandas inside a generator, and I switched to dask to load the data into memory instead. So, depending on your data size, if possible, load the data into memory and use the fit function.
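As a rough sketch of the "load it into memory" approach, the snippet below drains a finite batch generator into two NumPy arrays. It uses NumPy rather than dask to stay small; batch_generator and materialize are hypothetical names, and the data is synthetic. It assumes the full dataset fits in RAM and that the generator terminates.

```python
import numpy as np


def batch_generator(num_batches=5, batch_size=4, num_features=3):
    """Hypothetical generator standing in for one that reads batches from disk."""
    for i in range(num_batches):
        x = np.full((batch_size, num_features), i, dtype=np.float32)
        y = np.full((batch_size,), i % 2, dtype=np.int64)
        yield x, y


def materialize(generator):
    """Drain a finite (x, y) batch generator into two in-memory arrays."""
    xs, ys = zip(*generator)
    return np.concatenate(xs), np.concatenate(ys)


x_all, y_all = materialize(batch_generator())
# With the arrays in memory, they can be passed to model.fit(x_all, y_all),
# which avoids reading from disk for every batch.
```

Generators written for fit_generator() usually loop forever, so in practice the draining step would need to stop after one pass over the data.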

nafizh answered Oct 04 '22 10:10
