Calling "fit_generator()" multiple times in Keras

I have a generator function which generates tuples of (inputs, targets) on which my model is trained using the fit_generator() method in Keras.

My dataset is divided into 9 equal parts. I wish to perform leave-one-out cross-validation on the dataset using the fit_generator() method while keeping the learned parameters from the previous training intact.

My question is: will calling fit_generator() multiple times on the model make it re-learn the parameters it learned on the previous train and validation sets from scratch, or will it keep those learned parameters intact, leading to an improvement in accuracy?

After a little digging I found that the fit() method in Keras retains the learned parameters (see Calling "fit" multiple times in Keras), but I'm not sure whether the same holds for fit_generator(), and if it does, whether it can be used for cross-validation of the data.

The pseudo-code I'm thinking of implementing to achieve the cross-validation is as follows:

class DatasetGenerator(Sequence):
    def __init__(self, validation_id, mode):
        # Some code

    def __getitem__(self, idx):
        # The generator function

        # Some code

        return (inputs, targets)

for id in range(9):

    train_set = DatasetGenerator(id, 'train')
    # train_set contains the 8 remaining parts, leaving the id part out for validation.

    val_set = DatasetGenerator(id, 'val')
    # val_set contains the id part.

    history = model.fit_generator(train_set, epochs=10, steps_per_epoch=24000,
                                  validation_data=val_set, validation_steps=3000)

print('History Dict:', history.history)
results = model.evaluate_generator(test_set, steps=steps)
print('Test loss, acc:', results)

Will the model keep the learned parameters intact and improve upon them for each iteration of the for loop?

asked by Rohan Lekhwani

2 Answers

fit and fit_generator behave the same in this regard: calling either of them again will resume training from the previously trained weights.
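A minimal sketch (with a hypothetical toy model and random data, not the asker's setup) illustrating this resume-from-current-weights behavior:

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Toy model: any compiled model will do for the demonstration.
model = Sequential([Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

x = np.random.rand(64, 4)
y = x.sum(axis=1, keepdims=True)

model.fit(x, y, epochs=5, verbose=0)
loss_after_first = model.evaluate(x, y, verbose=0)

model.fit(x, y, epochs=5, verbose=0)  # continues from the current weights
loss_after_second = model.evaluate(x, y, verbose=0)

# The second loss is typically lower, because training resumed rather than restarted.
print(loss_after_first, loss_after_second)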

Also note that what you are trying to do is not cross-validation. In real cross-validation you train one model per fold, and the models are completely independent; each one is not continued from the training of the previous fold.
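In code terms, that means building a fresh model inside the loop. Here is a hedged sketch based on the pseudo-code in the question; build_model() is a hypothetical factory that returns a freshly initialized, compiled model:

fold_scores = []
for fold_id in range(9):
    model = build_model()  # hypothetical factory: fresh weights for every fold
    train_set = DatasetGenerator(fold_id, 'train')
    val_set = DatasetGenerator(fold_id, 'val')
    model.fit_generator(train_set, epochs=10, steps_per_epoch=24000,
                        validation_data=val_set, validation_steps=3000)
    # Assumes the model was compiled with metrics=['accuracy'].
    loss, acc = model.evaluate_generator(val_set, steps=3000)
    fold_scores.append(acc)

# The cross-validation estimate is the average over the independent folds.
print('Mean CV accuracy:', sum(fold_scores) / len(fold_scores))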

answered by Dr. Snoopy


As far as I know, it will keep the previously trained parameters. Also, I think what you are trying to do can be done by modifying the on_epoch_end() method of Sequence. It could look something like this:

class DatasetGenerator(Sequence):
    def __init__(self, id, mode):
        self.id = id
        self.mode = mode
        self.current_epoch = 0
        # some code

    def __getitem__(self, idx):
        id = self.id
        # Some code
        return (inputs, targets)

    def on_epoch_end(self):
        # Keras calls this at the end of every epoch.
        self.current_epoch += 1
        if self.current_epoch % 10 == 0:
            self.id += 1  # switch to the next held-out part every 10 epochs
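With this approach you would make a single long fit_generator call and let on_epoch_end rotate the parts. A hypothetical usage sketch (90 epochs = 9 parts x 10 epochs each):

train_set = DatasetGenerator(0, 'train')
val_set = DatasetGenerator(0, 'val')

# One long call: on_epoch_end advances self.id every 10 epochs,
# so the training generator rotates through all 9 parts over 90 epochs.
model.fit_generator(train_set, epochs=90, steps_per_epoch=24000,
                    validation_data=val_set, validation_steps=3000)

Note that val_set would need the same rotation logic in its own on_epoch_end if the held-out part is supposed to follow along.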
answered by meowongac


