I am looking for the best approach to train on larger-than-memory-data in Keras and currently noticing that the vanilla ImageDataGenerator tends to be slower than I would hope.
I have two networks training on the Kaggle cats vs. dogs dataset (25000 images):
1) this approach is exactly the code from: http://www.pyimagesearch.com/2016/09/26/a-simple-neural-network-with-python-and-keras/
2) same as (1) but using an ImageDataGenerator instead of loading the data into memory
Note: for below, "preprocessing" means resizing, scaling, flattening
I find the following on my gtx970:
For network 1, it takes ~0s per epoch.
For network 2, it takes ~36s per epoch if the preprocessing is done in the data generator.
For network 2, it takes ~13s per epoch if preprocessing is done in a first-pass outside of the data generator.
Is this likely the speed limit for ImageDataGenerator (13s seems like the usual 10-100x difference between disk and RAM...)? Are there approaches/mechanisms better suited for training on larger-than-memory data when using Keras? E.g., perhaps there is a way to get the ImageDataGenerator in Keras to save its processed images after the first epoch?
Thanks!
I assume you already might have solved this, but nevertheless...
Keras image preprocessing has the option of saving the results by setting the `save_to_dir` argument of the `flow()` or `flow_from_directory()` function:
https://keras.io/preprocessing/image/
In my understanding, the problem is that augmented images are used only once in a training cycle of a model, not even across several epochs. So it's a huge waste of GPU cycles while the CPU is struggling. I found the following solution:
This approach keeps the GPU busy most of the time while still benefiting from data augmentation. I use a custom Sequence subclass to generate augmentations and fix class imbalance at the same time.
EDIT: adding some code to clarify the idea
import json
import logging
import random
from gc import collect
from pathlib import Path

import cv2
import numpy as np
import tensorflow
from pyutilz.string import read_config_file
from tqdm.notebook import tqdm
class StoppingFromFile(tensorflow.keras.callbacks.Callback):
    """Keras callback that stops training when a control file asks for it.

    At the end of every epoch the ``stop`` option of the ``ML`` section in
    ``control.ini`` is read through ``read_config_file``.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Set ``self.model.stop_training`` if the control file requests a stop.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metric dict supplied by Keras (unused).
        """
        if not read_config_file('control.ini', 'ML', 'stop', globals()):
            return
        # NOTE(review): read_config_file is handed globals(), so it presumably
        # stores the parsed value there under the name 'stop' -- confirm.
        # Reading it with .get() avoids a NameError if it did not.
        stop_flag = globals().get('stop')
        # Accept both a parsed boolean and the raw string 'True'.
        if stop_flag in (True, 'True'):
            logging.warning('Model should be stopped according to the control fole')
            self.model.stop_training = True
class AugmentedBalancedSequence(tensorflow.keras.utils.Sequence):
    """In-memory, class-balanced dataset that is re-augmented every few epochs.

    From a dict of images grouped by class label, creates each N epochs
    (``frame_length``) an augmented balanced training set.
    If the current class is too scarce, ensures that the current frame has no
    duplicate final images. If it's rich enough, ensures that the current
    frame has no duplicate base images.

    The generated samples live in ``self.images`` / ``self.labels`` and are
    served in ``batch_size`` slices through the ``Sequence`` protocol.
    """

    def __init__(self, images_and_classes: dict, input_size: tuple, class_sizes: dict,
                 augmentations_fn: object, preprocessing_fn: object, batch_size: int = 10,
                 num_class_samples=100, frame_length: int = 5, aug_p: float = 0.1,
                 aug_pipe_p: float = 0.2, is_validation: bool = False,
                 disk_saving_prob: float = .01, disk_example_nfiles: int = 50):
        """Store the configuration and generate the first frame of data.

        Args:
            images_and_classes: Mapping ``label -> sequence of images``. The
                images are handed straight to the augmentation pipeline and
                copied into the buffer, so they are presumably decoded RGB
                arrays of identical size -- TODO confirm against the caller.
            input_size: Indexable pair; the buffer is allocated as
                ``(n, input_size[1], input_size[0], 3)``.
            class_sizes: Mapping ``label -> sample count``. Stored for
                compatibility; the buffer size is derived from
                ``images_and_classes`` itself.
            augmentations_fn: Factory called as ``augmentations_fn(p=, pipe_p=)``
                returning a callable used as ``pipeline(image=img)['image']``.
            preprocessing_fn: Optional callable applied to every image before
                it is stored; falsy to store images unchanged.
            batch_size: Number of samples per batch.
            num_class_samples: Samples to produce per class. ``None`` (or 0)
                means "use every native sample once, without augmentation".
            frame_length: Regenerate the data every this many epochs.
            aug_p: Forwarded to ``augmentations_fn`` as ``p``.
            aug_pipe_p: Forwarded to ``augmentations_fn`` as ``pipe_p`` for
                classes that have enough native samples.
            is_validation: When true, native images are stored unaugmented.
            disk_saving_prob: Probability of dumping a generated image to
                ``example_images/`` for visual inspection.
            disk_example_nfiles: Number of example files to cycle through.
        """
        super().__init__()
        logging.info(f'Got {len(images_and_classes)} classes.')
        self.disk_example_nfiles = disk_example_nfiles
        self.disk_saving_prob = disk_saving_prob
        self.cur_example_file = 0
        self.images_and_classes = images_and_classes
        self.num_class_samples = num_class_samples
        self.augmentations_fn = augmentations_fn
        self.preprocessing_fn = preprocessing_fn
        self.is_validation = is_validation
        self.frame_length = frame_length
        self.batch_size = batch_size
        self.class_sizes = class_sizes
        self.input_size = input_size
        self.aug_pipe_p = aug_pipe_p
        self.aug_p = aug_p
        self.images = None
        self.labels = None
        self.epoch = 0
        self._generate_data()

    def __len__(self):
        """Return the number of batches per epoch (last one may be partial)."""
        return int(np.ceil(len(self.images) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Return batch ``idx`` as an ``(images, labels)`` pair of array slices."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.images[start:stop], self.labels[start:stop]

    def on_epoch_end(self):
        """Advance the epoch counter, apply live overrides, maybe regenerate."""
        self.epoch += 1
        self._apply_control_overrides()
        if self.epoch % self.frame_length == 0:
            self._generate_data()

    def _apply_control_overrides(self, fname='control.json'):
        """Overwrite existing attributes with non-null values from a JSON file.

        Best effort: any failure is logged and training continues.
        """
        if not Path(fname).is_file():
            return
        try:
            with open(fname) as f:
                overrides = json.load(f)
            for var, val in overrides.items():
                # Only known attributes may be changed, and only to a new,
                # non-null value.
                if not hasattr(self, var) or val is None:
                    continue
                if getattr(self, var) != val:
                    setattr(self, var, val)
                    print(f'{var} became {val}')
        except Exception as e:
            logging.error(str(e))

    def _save_example(self, image):
        """Write ``image`` to ``example_images/``, cycling over a fixed set of file names."""
        self.cur_example_file += 1
        if self.cur_example_file > self.disk_example_nfiles:
            self.cur_example_file = 1
        Path(r'example_images/').mkdir(parents=True, exist_ok=True)
        # The image is converted RGB -> BGR because OpenCV writes BGR.
        cv2.imwrite(f'example_images/test{self.cur_example_file}.jpg',
                    cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    def _add_sample(self, image, label):
        """Store one sample in the next pre-shuffled slot of the buffers."""
        slot = self.indices[self.img_sent]
        if self.disk_saving_prob > 0 and random.random() < self.disk_saving_prob:
            self._save_example(image)
        if self.preprocessing_fn:
            self.images[slot] = self.preprocessing_fn(image)
        else:
            self.images[slot] = image
        self.labels[slot] = label
        self.img_sent += 1

    def _add_from_rich_class(self, images, label, pipeline):
        """Add ``num_class_samples`` distinct base images of a well-populated class."""
        chosen = np.random.choice(len(images), self.num_class_samples, replace=False)
        for i in chosen:
            if self.is_validation:
                self._add_sample(images[i], label)
            else:
                self._add_sample(pipeline(image=images[i])['image'], label)

    def _add_from_scarce_class(self, images, label, pipeline):
        """Fill a scarce class up to ``num_class_samples`` samples.

        In validation mode every original is added once first. The rest is
        produced by augmenting randomly picked originals until enough images
        with distinct pixel sums have been collected.

        Returns:
            The number of augmented (artificial) samples added.
        """
        # The pixel sum is a cheap fingerprint; a collision only causes a retry.
        seen = set()
        n_orig = 0
        n_added = 0
        n_augmented = 0
        # NOTE: this loops forever if the pipeline cannot produce enough
        # distinct images (e.g. an identity pipeline).
        while n_added < self.num_class_samples:
            take_original = self.is_validation and n_orig < len(images)
            if take_original:
                image = images[n_orig]
                n_orig += 1
            else:
                image = pipeline(image=random.choice(images))['image']
            fingerprint = np.sum(image)
            # Originals are always kept; augmented images must be new.
            if not take_original and fingerprint in seen:
                continue
            seen.add(fingerprint)
            self._add_sample(image, label)
            n_added += 1
            if not take_original:
                n_augmented += 1
        return n_augmented

    def _generate_data(self):
        """(Re)build the in-memory frame of samples and labels."""
        logging.info('Generating new set of augmented data...')
        collect()

        # Falsy num_class_samples (None or 0) means: all native samples, as is.
        use_native_only = not self.num_class_samples
        if use_native_only:
            expected_length = sum(len(images) for images in self.images_and_classes.values())
        else:
            expected_length = len(self.images_and_classes) * self.num_class_samples

        # Reallocate when the target size changed (e.g. num_class_samples was
        # overridden at runtime); otherwise reuse the existing buffers.
        if self.images is None or len(self.images) != expected_length:
            self.images = np.empty((expected_length,) + (self.input_size[1],) + (self.input_size[0],) + (3,))
            self.labels = np.empty((expected_length), np.int32)

        # Random slot order, so classes end up interleaved in the buffer.
        self.indices = np.random.choice(expected_length, expected_length, replace=False)
        self.img_sent = 0
        collect()

        relaxed_augmentation_pipeline = self.augmentations_fn(p=self.aug_p, pipe_p=self.aug_pipe_p)
        maxed_out_augmentation_pipeline = self.augmentations_fn(p=self.aug_p, pipe_p=1.0)

        nartificial = 0
        for label, images in tqdm(self.images_and_classes.items()):
            if use_native_only:
                for image in images:
                    self._add_sample(image, label)
            elif len(images) >= self.num_class_samples:
                self._add_from_rich_class(images, label, relaxed_augmentation_pipeline)
            else:
                nartificial += self._add_from_scarce_class(
                    images, label, maxed_out_augmentation_pipeline)
        logging.info(f'Generated {self.img_sent} samples ({nartificial} artificial)')
Once I have the images and classes loaded, I create the generators:
# Training set: UPSCALE_SAMPLES samples per class, regenerated every
# FRAME_LENGTH epochs; 5% of the generated images are dumped to disk.
train_datagen = AugmentedBalancedSequence(
    images_and_classes=images_and_classes_train,
    class_sizes=class_sizes_train,
    input_size=INPUT_SIZE,
    batch_size=BATCH_SIZE,
    frame_length=FRAME_LENGTH,
    num_class_samples=UPSCALE_SAMPLES,
    augmentations_fn=get_albumentations_pipeline,
    aug_p=AUG_P,
    aug_pipe_p=AUG_PIPE_P,
    preprocessing_fn=preprocess_input,
    disk_saving_prob=0.05,
)

# Validation set: num_class_samples=None keeps every native sample as is.
val_datagen = AugmentedBalancedSequence(
    images_and_classes=images_and_classes_val,
    class_sizes=class_sizes_val,
    input_size=INPUT_SIZE,
    batch_size=BATCH_SIZE,
    frame_length=FRAME_LENGTH,
    num_class_samples=None,
    augmentations_fn=get_albumentations_pipeline,
    preprocessing_fn=preprocess_input,
    is_validation=True,
)
and after the model is instantiated, I do
# Train from the balanced sequence; validation uses the arrays already held
# in memory by val_datagen.
model.fit(
    train_datagen,
    epochs=600,
    verbose=1,
    validation_data=(val_datagen.images, val_datagen.labels),
    validation_batch_size=BATCH_SIZE,
    validation_freq=1,
    callbacks=[checkpointer, StoppingFromFile()],
)
If you love us, you can donate to us via PayPal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With