Is fit_generator in Keras supposed to reset the generator after each epoch?

I am trying to use fit_generator with a custom generator to read in data that's too big for memory. There are 1.25 million rows I want to train on, so I have the generator yield 50,000 rows at a time. fit_generator is given steps_per_epoch=25, which I thought would bring in those 1.25 million rows per epoch. I added a print statement so I could see how far the offset had advanced, and I found that it exceeded the maximum a few steps into epoch 2. There are 1.75 million records in that file in total, and once the generator gets 10 steps into epoch 2 it hits an index error in the create_feature_matrix call (because it reads in no rows).

import pandas as pd

def get_next_data_batch():
    import gc
    nrows = 50000      # rows yielded per call
    skiprows = 0       # running offset into the CSV

    while True:
        # Read the next chunk, skipping the rows already consumed
        d = pd.read_csv(file_loc, skiprows=range(1, skiprows), nrows=nrows, index_col=0)
        print(skiprows)
        x, y = create_feature_matrix(d)
        yield x, y
        skiprows = skiprows + nrows
        gc.collect()

get_data = get_next_data_batch()

... set up a Keras NN ...

model.fit_generator(get_next_data_batch(), epochs=100, steps_per_epoch=25, verbose=1, workers=4, callbacks=callbacks_list)

Am I using fit_generator wrong or is there some change that needs to be made to my custom generator to get this to work?

asked Feb 11 '18 by user4446237

People also ask

What does model.fit_generator do?

The .fit_generator() function accepts a batch of the dataset from the generator, performs backpropagation on it, and updates the model's weights. The process is repeated for the specified number of epochs.

What is Steps_per_epoch?

steps_per_epoch: Total number of steps (batches of samples) to yield from generator before declaring one epoch finished and starting the next epoch. It should typically be equal to the number of unique samples of your dataset divided by the batch size.
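
To connect this to the numbers in the question, here is a minimal back-of-the-envelope check (assuming the figures given above: a 1.75-million-row file read in 50,000-row batches, aiming for 1.25 million rows per epoch):

import math

rows_per_epoch = 1_250_000   # rows the asker wants to train on per epoch
batch_rows = 50_000          # rows yielded per generator call
total_rows = 1_750_000       # rows actually in the file

steps_per_epoch = math.ceil(rows_per_epoch / batch_rows)   # 25
total_batches_in_file = total_rows // batch_rows           # 35

# A plain generator keeps advancing its offset across epochs, so the file is
# exhausted after 35 calls: 25 calls in epoch 1 plus 10 calls into epoch 2,
# which matches where the index error appears.
print(steps_per_epoch, total_batches_in_file)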


1 Answer

No - fit_generator doesn't reset the generator; it simply keeps calling it. To get the behaviour you want, you could try the following:

def get_next_data_batch(nb_of_calls_before_reset=25):
    import gc
    nrows = 50000      # rows yielded per call
    skiprows = 0       # running offset into the CSV
    nb_calls = 0       # calls made since the last reset

    while True:
        d = pd.read_csv(file_loc, skiprows=range(1, skiprows), nrows=nrows, index_col=0)
        print(skiprows)
        x, y = create_feature_matrix(d)
        yield x, y
        nb_calls += 1
        if nb_calls == nb_of_calls_before_reset:
            # One epoch's worth of batches has been served - rewind to the
            # top of the file for the next epoch.
            skiprows = 0
            nb_calls = 0
        else:
            skiprows = skiprows + nrows
        gc.collect()
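
For completeness, a sketch of how this could be wired up with the fit call from the question, assuming the reset interval is kept equal to steps_per_epoch so the generator rewinds at every epoch boundary (the workers=4 option from the question is omitted here, since a stateful plain generator is generally not safe to call from multiple workers):

train_gen = get_next_data_batch(nb_of_calls_before_reset=25)

# steps_per_epoch must match nb_of_calls_before_reset, otherwise the
# generator's file offset and Keras's notion of an epoch drift apart.
model.fit_generator(train_gen, epochs=100, steps_per_epoch=25,
                    verbose=1, callbacks=callbacks_list)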
answered Oct 01 '22 by Marcin Możejko