Use a generator for Keras model.fit_generator

I originally tried to use generator syntax when writing a custom generator for training a Keras model, so I yielded from __next__. However, when I tried to train my model with model.fit_generator, I got an error saying that my generator was not an iterator. The fix was to change yield to return, which also meant reworking the logic of __next__ to track state. That is quite cumbersome compared to letting yield do the work for me.

Is there a way I can make this work with yield? I need to write several more iterators, and their logic will be very clunky if I have to use a return statement.

doogFromMT asked Sep 29 '17 16:09


1 Answer

I can't help debug your code since you didn't post it, but I abbreviated a custom data generator I wrote for a semantic segmentation project for you to use as a template:

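import os
import random

import cv2
import numpy as np

INPUT_SHAPE = (256, 256)  # assumed (width, height) for cv2.resize; match your model's input


def generate_data(directory, batch_size):
    """Replaces Keras' native ImageDataGenerator."""
    directory = os.path.expanduser(directory)  # os.listdir does not expand '~'
    file_list = os.listdir(directory)
    i = 0
    while True:
        image_batch = []
        for _ in range(batch_size):
            if i == len(file_list):
                # Finished a pass over the data: start over in a new order.
                i = 0
                random.shuffle(file_list)
            sample = file_list[i]
            i += 1
            # os.listdir returns bare file names, so join them with the directory.
            image = cv2.resize(cv2.imread(os.path.join(directory, sample)), INPUT_SHAPE)
            # Scale pixel values from [0, 255] to roughly [-1, 1].
            image_batch.append((image.astype(float) - 128) / 128)

        # fit_generator expects (inputs, targets) tuples; the target batch is
        # omitted in this abbreviated template.
        yield np.array(image_batch)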

Usage:

model.fit_generator(
    generate_data('~/my_data', batch_size),
    steps_per_epoch=len(os.listdir(os.path.expanduser('~/my_data'))) // batch_size)
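
On the original question of keeping yield: a generator function like the one above already returns an object with both __iter__ and __next__, so no class is required. If you would still like a class to hold configuration, one option is to put the yield in __iter__ instead of __next__. With yield inside __next__, each next() call hands back a new generator object rather than a batch. The sketch below illustrates both cases; YieldInNext and BatchSource are hypothetical names, the samples are plain integers instead of images, and nothing here is Keras-specific.

```python
import random
import types


class YieldInNext:
    """The problematic pattern: yield placed inside __next__."""

    def __next__(self):
        # The yield turns __next__ into a generator function, so calling it
        # returns a generator object instead of the batch itself.
        yield [0, 1]


class BatchSource:
    """Hypothetical class-based source; the yield lives in __iter__."""

    def __init__(self, samples, batch_size, seed=0):
        self.samples = list(samples)
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def __iter__(self):
        # Calling this method returns a generator, which implements the
        # iterator protocol (__iter__ and __next__) automatically.
        i = 0
        while True:
            batch = []
            for _ in range(self.batch_size):
                if i == len(self.samples):
                    # Finished a pass: start over in a new order.
                    i = 0
                    self.rng.shuffle(self.samples)
                batch.append(self.samples[i])
                i += 1
            # For Keras you would wrap this, e.g. in np.array(batch).
            yield batch


leaked = next(YieldInNext())
leaked_is_generator = isinstance(leaked, types.GeneratorType)

batches = iter(BatchSource(range(5), batch_size=2))
first = next(batches)
second = next(batches)
third = next(batches)  # crosses the end of the data and triggers a reshuffle
```

Under these assumptions, passing iter(BatchSource(...)) to model.fit_generator should behave like passing a plain generator function's result, since both are generator objects.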
Jessica Alan answered Sep 16 '22 16:09