 

Keras ImageDataGenerator Slow

I am looking for the best approach to training on larger-than-memory data in Keras, and I am noticing that the vanilla ImageDataGenerator is slower than I would hope.

I have two networks training on the Kaggle Cats vs. Dogs dataset (25,000 images):

1) exactly the code from: http://www.pyimagesearch.com/2016/09/26/a-simple-neural-network-with-python-and-keras/

2) the same as (1), but using an ImageDataGenerator instead of loading the data into memory

Note: below, "preprocessing" means resizing, scaling, and flattening.
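For reference, a minimal sketch of what that preprocessing does to a single image. Assumptions: the 32x32 target size is chosen only for illustration, and a plain NumPy nearest-neighbour lookup stands in for `cv2.resize` so that the snippet runs without OpenCV.

```python
import numpy as np

def preprocess(image: np.ndarray, size=(32, 32)) -> np.ndarray:
    """Resize (nearest neighbour), scale to [0, 1] and flatten one HxWxC uint8 image.

    size is (width, height), the same order cv2.resize expects.
    """
    h, w = image.shape[:2]
    out_w, out_h = size
    # For each output pixel, pick the nearest source row / column
    rows = np.arange(out_h) * h // out_h
    cols = np.arange(out_w) * w // out_w
    resized = image[rows[:, None], cols[None, :]]      # shape (out_h, out_w, C)
    # Scale to [0, 1] and flatten into a single feature vector
    return (resized.astype(np.float32) / 255.0).ravel()

img = np.random.randint(0, 256, size=(120, 90, 3), dtype=np.uint8)
vec = preprocess(img)  # 32 * 32 * 3 = 3072 features
```

Doing this inside the generator means it is repeated for every image in every epoch, which is why moving it into a one-off first pass helps.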

I find the following on my GTX 970:

For network 1, it takes ~0s per epoch.

For network 2, it takes ~36s per epoch if the preprocessing is done in the data generator.

For network 2, it takes ~13s per epoch if preprocessing is done in a first-pass outside of the data generator.
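The "first pass outside the generator" variant can be sketched roughly as follows: preprocess every image once, store the results in a single on-disk `.npy` file, and in later epochs read shuffled batches from it through a memory map, so the whole dataset never has to fit in RAM. The helper names `cache_preprocessed` and `iterate_batches` are hypothetical; in real use `images` would typically be a lazy sequence that loads each file from disk.

```python
import os
import tempfile
import numpy as np

def cache_preprocessed(images, preprocess_fn, path):
    """First pass: run preprocess_fn once per image and store all results in one .npy file."""
    first = preprocess_fn(images[0])
    out = np.lib.format.open_memmap(path, mode="w+", dtype=first.dtype,
                                    shape=(len(images),) + first.shape)
    out[0] = first
    for i in range(1, len(images)):
        out[i] = preprocess_fn(images[i])
    out.flush()
    del out  # release the memmap so the file is complete on disk

def iterate_batches(path, labels, batch_size):
    """Later epochs: yield shuffled batches read lazily from the memory-mapped file."""
    data = np.load(path, mmap_mode="r")
    order = np.random.permutation(len(data))
    for start in range(0, len(data), batch_size):
        idx = np.sort(order[start:start + batch_size])  # sorted reads are friendlier to the disk
        yield np.asarray(data[idx]), labels[idx]

# Usage with toy data: image i is filled with the value i, and its label is i
tmp = os.path.join(tempfile.mkdtemp(), "cache.npy")
raw = [np.full((4, 4, 3), i, np.uint8) for i in range(10)]
labels = np.arange(10)
cache_preprocessed(raw, lambda im: (im.astype(np.float32) / 255).ravel(), tmp)
batches = list(iterate_batches(tmp, labels, batch_size=4))
```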

Is this likely the speed limit for ImageDataGenerator (13s seems consistent with the usual 10-100x difference between disk and RAM)? Are there approaches or mechanisms better suited to training on larger-than-memory data in Keras? For example, is there a way to get ImageDataGenerator to save its processed images after the first epoch?

Thanks!

asked Dec 10 '16 03:12 by John Cast


2 Answers

I assume you may have already solved this, but nevertheless:

Keras image preprocessing can save its results: set the save_to_dir argument of the flow() or flow_from_directory() method:

https://keras.io/preprocessing/image/
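Note that save_to_dir only writes the augmented pictures out (the docs describe it as useful for visualizing what you are doing); the generator does not read them back in later epochs. A minimal sketch of a wrapper that does replay them, assuming a hypothetical `CachedBatches` class around any finite batch iterator: the first pass generates batches and saves each one as an `.npz` file, and later passes load them from disk instead of regenerating. Because Keras `flow()` iterators loop forever, `make_batches` would need to stop after one epoch, e.g. `lambda: itertools.islice(datagen.flow(x, y, batch_size=32), steps)`.

```python
import os
import tempfile
import numpy as np

class CachedBatches:
    """Epoch 1 pulls batches from make_batches() and saves them; later epochs replay them from disk."""

    def __init__(self, make_batches, cache_dir):
        self.make_batches = make_batches  # callable returning a finite iterator of (x, y)
        self.cache_dir = cache_dir
        self.n_batches = None             # unknown until the first full pass finishes

    def _path(self, i):
        return os.path.join(self.cache_dir, f"batch_{int(i):06d}.npz")

    def __iter__(self):
        if self.n_batches is None:
            # First pass: generate (and augment) once, writing every batch to disk
            os.makedirs(self.cache_dir, exist_ok=True)
            n = 0
            for x, y in self.make_batches():
                np.savez(self._path(n), x=x, y=y)
                n += 1
                yield x, y
            self.n_batches = n
        else:
            # Later passes: read the saved batches back in a shuffled order
            for i in np.random.permutation(self.n_batches):
                with np.load(self._path(i)) as f:
                    yield f["x"], f["y"]

def make_batches():
    for i in range(3):
        yield np.full((2, 5), i, np.float32), np.array([i, i])

cached = CachedBatches(make_batches, tempfile.mkdtemp())
first = [int(y[0]) for _, y in cached]   # epoch 1: generated and written to disk
second = [int(y[0]) for _, y in cached]  # epoch 2: loaded from disk
```

The trade-off is that replayed batches carry the same augmentations every epoch, so this suits expensive deterministic preprocessing better than heavy random augmentation.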

answered Sep 25 '22 22:09 by petezurich


As I understand it, the problem is that each augmented image is used only once during training, not even across several epochs. So GPU cycles are wasted while the CPU struggles to keep up. I found the following solution:

  1. I generate as many augmented images in RAM as I can.
  2. I use them for training across a frame of epochs (10 to 30, whatever it takes to get noticeable convergence).
  3. After that, I generate a new batch of augmented images (by implementing on_epoch_end), and the process repeats.
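The three steps above can be sketched in a few lines. Assumptions: `FramedAugmentation` is a hypothetical plain class that exposes the same `__len__` / `__getitem__` / `on_epoch_end` hooks as `tensorflow.keras.utils.Sequence` (in real use you would subclass `Sequence`, as the code below does), and `augment(image, rng)` stands in for a real augmentation pipeline.

```python
import numpy as np

class FramedAugmentation:
    """Keeps a pool of augmented samples in RAM and rebuilds it only every frame_length epochs."""

    def __init__(self, x, y, augment, pool_size, batch_size, frame_length, seed=0):
        self.x, self.y = x, y
        self.augment = augment            # hypothetical: callable(image, rng) -> augmented image
        self.pool_size = pool_size
        self.batch_size = batch_size
        self.frame_length = frame_length
        self.rng = np.random.default_rng(seed)
        self.epoch = 0
        self.regenerations = 0
        self._regenerate()

    def _regenerate(self):
        # Step 1: fill RAM with pool_size augmented samples drawn from the originals
        idx = self.rng.integers(0, len(self.x), size=self.pool_size)
        self.pool_x = np.stack([self.augment(self.x[i], self.rng) for i in idx])
        self.pool_y = self.y[idx]
        self.regenerations += 1

    def __len__(self):
        return int(np.ceil(self.pool_size / self.batch_size))

    def __getitem__(self, i):
        s = slice(i * self.batch_size, (i + 1) * self.batch_size)
        return self.pool_x[s], self.pool_y[s]

    def on_epoch_end(self):
        # Step 2: keep training on the same pool for frame_length epochs ...
        self.epoch += 1
        # Step 3: ... then replace it with a fresh set of augmentations
        if self.epoch % self.frame_length == 0:
            self._regenerate()

def random_flip(img, rng):
    return img[:, ::-1] if rng.random() < 0.5 else img

x = np.random.rand(20, 8, 8, 3).astype(np.float32)
y = np.arange(20)
seq = FramedAugmentation(x, y, random_flip, pool_size=50, batch_size=16, frame_length=10)
for epoch in range(30):          # what model.fit would do with the Sequence
    for b in range(len(seq)):
        xb, yb = seq[b]
    seq.on_epoch_end()
```

With `frame_length=10`, the pool is built once up front and rebuilt after epochs 10, 20 and 30, so the CPU-heavy augmentation runs 4 times instead of 30.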

Most of the time this approach keeps the GPU busy while still benefiting from data augmentation. I use a custom Sequence subclass to generate the augmentations and fix class imbalance at the same time.
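The class-balancing part can be illustrated separately. A minimal sketch, assuming a hypothetical helper `balanced_indices`: every class contributes exactly `per_class` samples; rich classes are subsampled without replacement, while scarce classes keep all their originals plus random repeats, which would then be sent through the augmentation pipeline so the repeats become distinct images (the code below does this with its "maxed out" pipeline).

```python
import numpy as np

def balanced_indices(labels, per_class, rng=None):
    """Return shuffled sample indices with exactly per_class entries for every class."""
    rng = np.random.default_rng(rng)
    labels = np.asarray(labels)
    chosen = []
    for c in np.unique(labels):
        members = np.flatnonzero(labels == c)
        if len(members) >= per_class:
            # Rich class: no duplicate base images within this frame
            picked = rng.choice(members, per_class, replace=False)
        else:
            # Scarce class: all originals once, then random repeats to be augmented
            extra = rng.choice(members, per_class - len(members), replace=True)
            picked = np.concatenate([members, extra])
        chosen.append(picked)
    out = np.concatenate(chosen)
    rng.shuffle(out)  # mix classes so batches are not single-class
    return out

labels = np.array([0] * 90 + [1] * 10)
idx = balanced_indices(labels, per_class=40, rng=0)
```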

EDIT: adding some code to clarify the idea

import json
import logging
import random
from gc import collect
from pathlib import Path

import cv2
import numpy as np
import tensorflow
from pyutilz.string import read_config_file
from tqdm.notebook import tqdm

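class StoppingFromFile(tensorflow.keras.callbacks.Callback):
    """Stops training when the 'stop' option of the [ML] section in control.ini is True."""
    def on_epoch_end(self, epoch, logs=None):
        # read_config_file (pyutilz) is assumed to store the option it reads ('stop') in the dict passed to it (globals())
        if read_config_file('control.ini','ML','stop',globals()):
            stop = globals().get('stop')
            if stop in (True, 'True'):
                logging.warning('Model should be stopped according to the control file')
                self.model.stop_training = True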

class AugmentedBalancedSequence(tensorflow.keras.utils.Sequence):
    def __init__(self, images_and_classes:dict,input_size:tuple,class_sizes:dict, augmentations_fn:object, preprocessing_fn:object, batch_size:int=10,
                 num_class_samples=100, frame_length:int=5, aug_p:float=0.1,aug_pipe_p:float=0.2,is_validation:bool=False,
                disk_saving_prob:float=.01,disk_example_nfiles:int=50):
        """
            From a dict of images grouped by class label, creates a new augmented, balanced training set every frame_length epochs.
            If the current class is too scarce, ensures that the current frame has no duplicate final images.
            If it is rich enough, ensures that the current frame has no duplicate base images.
        
        """
        logging.info(f'Got {len(images_and_classes)} classes.')
        self.disk_example_nfiles=disk_example_nfiles
        self.disk_saving_prob=disk_saving_prob
        self.cur_example_file=0
        
        self.images_and_classes=images_and_classes        
        self.num_class_samples=num_class_samples
        self.augmentations_fn=augmentations_fn
        self.preprocessing_fn=preprocessing_fn
        
        self.is_validation=is_validation
        self.frame_length=frame_length                    
        self.batch_size = batch_size      
        self.class_sizes=class_sizes
        self.input_size=input_size        
        self.aug_pipe_p=aug_pipe_p
        self.aug_p=aug_p        
        self.images=None
        self.epoch = 0
        #print(f'got frame_length={self.frame_length}')
        self._generate_data()
        

    def __len__(self):
        return int(np.ceil(len(self.images)/ float(self.batch_size)))

    def __getitem__(self, idx):
        a=idx * self.batch_size
        b=a+self.batch_size
        return self.images[a:b],self.labels[a:b]
    
    def on_epoch_end(self):
        import pathlib
        self.epoch += 1

        # Allow attributes (e.g. aug_p, frame_length) to be changed while training
        # by putting {"attribute": value} pairs into control.json
        fname='control.json'
        if pathlib.Path(fname).is_file():
            try:
                with open(fname) as f:
                    mydict=json.load(f)
                for var,val in mydict.items():
                    if hasattr(self,var) and val is not None and getattr(self, var)!=val:
                        setattr(self, var, val)
                        print(f'{var} became {val}')
            except Exception as e:
                logging.error(str(e))
        # Regenerate the augmented pool once every frame_length epochs
        if self.epoch % self.frame_length == 0:
            self._generate_data()
            
    def _add_sample(self,image,label):
        # Write the sample into a shuffled slot so that consecutive batches mix classes
        idx=self.indices[self.img_sent]

        # Occasionally dump an example to disk (rotating over disk_example_nfiles files) to inspect the augmentations
        if self.disk_saving_prob>0 and random.random()<self.disk_saving_prob:
            self.cur_example_file+=1
            if self.cur_example_file>self.disk_example_nfiles:
                self.cur_example_file=1
            Path('example_images').mkdir(parents=True, exist_ok=True)
            cv2.imwrite(f'example_images/test{self.cur_example_file}.jpg',cv2.cvtColor(image,cv2.COLOR_RGB2BGR))
        
        if self.preprocessing_fn: 
            self.images[idx]=self.preprocessing_fn(image)
        else:
            self.images[idx]=image
        
        self.labels[idx]=label
        self.img_sent+=1        
        
    def _generate_data(self):
        logging.info('Generating new set of augmented data...')

        if self.num_class_samples:
            expected_length=len(self.images_and_classes)*self.num_class_samples
        else:
            expected_length=sum(self.class_sizes.values())

        # (Re)allocate the buffers only when their size changes, e.g. after num_class_samples was edited in control.json.
        # float32 halves memory compared with NumPy's default float64.
        shape=(expected_length,self.input_size[1],self.input_size[0],3)
        if self.images is None or self.images.shape!=shape:
            self.images=None
            collect()
            self.images=np.empty(shape,np.float32)
            self.labels=np.empty(expected_length,np.int32)

        # Random permutation of slots: _add_sample writes the n-th sample into slot indices[n]
        self.indices=np.random.choice(expected_length, expected_length, replace=False)
        self.img_sent=0

        relaxed_augmentation_pipeline=self.augmentations_fn(p=self.aug_p,pipe_p=self.aug_pipe_p)
        maxed_out_augmentation_pipeline=self.augmentations_fn(p=self.aug_p,pipe_p=1.0)

        nartificial=0
        for label,images in tqdm(self.images_and_classes.items()):
            if self.num_class_samples is None:
                # Just all native samples without augmentations
                for image in images:
                    self._add_sample(image,label)
            elif len(images)>=self.num_class_samples:
                # Enough native samples: randomly select the ones that will participate in this frame of epochs
                indices=np.random.choice(len(images), self.num_class_samples, replace=False)
                for idx in indices:
                    if not self.is_validation:
                        # apply the albumentations pipeline to the selected samples
                        self._add_sample(relaxed_augmentation_pipeline(image=images[idx])['image'],label)
                    else:
                        self._add_sample(images[idx],label)
            else:
                #------------------------------------------------------------------------------------------------------------------------------------------------------------------
                # Scarce class: randomly pick existing images and apply the augmentation pipeline (with maxed-out probability)
                # until we get num_class_samples DIFFERENT images. For validation, all originals are included first.
                #------------------------------------------------------------------------------------------------------------------------------------------------------------------
                seen=set()
                added=0
                norig=0
                while added<self.num_class_samples:
                    if self.is_validation and norig<len(images):
                        image=images[norig]
                        norig+=1
                    else:
                        image=maxed_out_augmentation_pipeline(image=random.choice(images))['image']
                        # The pixel sum is a cheap fingerprint used to skip duplicate augmentations
                        if np.sum(image) in seen:
                            continue
                        nartificial+=1
                    seen.add(np.sum(image))
                    self._add_sample(image,label)
                    added+=1

        logging.info(f'Generated {self.img_sent} samples ({nartificial} artificial)')
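The duplicate check in _generate_data uses np.sum(image) as a fingerprint. That is cheap, but it collides easily: a horizontally flipped image has exactly the same pixel sum as the original, so a flip augmentation would be rejected as a "duplicate". A minimal sketch of a stricter content fingerprint built on `hashlib`, which could replace `np.sum(...)` in that check (`image_fingerprint` is a hypothetical helper, not part of the code above):

```python
import hashlib
import numpy as np

def image_fingerprint(image: np.ndarray) -> str:
    """Content hash: equal only if shape, dtype and every pixel value match."""
    h = hashlib.blake2b(digest_size=16)
    h.update(str((image.shape, image.dtype.str)).encode())
    h.update(np.ascontiguousarray(image).tobytes())
    return h.hexdigest()

img = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)
flipped = img[:, ::-1]
print(int(img.sum()) == int(flipped.sum()))                  # True: pixel sums collide
print(image_fingerprint(img) == image_fingerprint(flipped))  # False: content hashes differ
```

Hashing the raw bytes costs a little more than a sum but still only touches each pixel once.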

Once the images and classes are loaded, I create the training and validation sequences:

train_datagen = AugmentedBalancedSequence(images_and_classes=images_and_classes_train,
                          input_size=INPUT_SIZE,class_sizes=class_sizes_train,num_class_samples=UPSCALE_SAMPLES,
    augmentations_fn=get_albumentations_pipeline,aug_p=AUG_P,aug_pipe_p=AUG_PIPE_P,preprocessing_fn=preprocess_input, batch_size=BATCH_SIZE,frame_length=FRAME_LENGTH,disk_saving_prob=0.05)

val_datagen = AugmentedBalancedSequence(images_and_classes=images_and_classes_val,
                                        input_size=INPUT_SIZE,class_sizes=class_sizes_val,num_class_samples=None,
    augmentations_fn=get_albumentations_pipeline,preprocessing_fn=preprocess_input, batch_size=BATCH_SIZE,frame_length=FRAME_LENGTH,is_validation=True)

After the model is instantiated, I train it with:

model.fit(train_datagen,epochs=600,verbose=1,
          validation_data=(val_datagen.images,val_datagen.labels),validation_batch_size=BATCH_SIZE,
          callbacks=[checkpointer,StoppingFromFile()],validation_freq=1)

answered Sep 23 '22 22:09 by Anatoly Alekseev