Keras fit_generator Callback

I'm using fit_generator with a generator that reads data from a file; when it reaches the end of the file, it loads data from the next file. I'm also using a stateful RNN in Keras, so I need to reset the state manually, in this case every time the generator loads a new file. How can I achieve this?

The generator looks something like this:

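import pandas as pd

def gendata():
    l = 32                  # rows per chunk (illustrative value)
    x = 0                   # current row offset in the open file
    path = 'somepath'
    df = pd.read_csv(path)
    while True:
        if x + l <= len(df):
            yield df.iloc[x:x+l, :]
            x += l
        else:
            path = newpath(path)    # my helper: returns the next file's path
            df = pd.read_csv(path)
            x = 0                   # start reading the new file from the top
            model.reset_states()    # this line obviously doesn't work: the generator has no reference to the model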
ndrue asked Apr 15 '26 at 13:04



1 Answer

Just pass the generator a reference to the model:

def gendata(model):
    ...
    model.reset_states()

model.fit_generator(gendata(model), ...)
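A fuller sketch of the same idea, filled in so it runs without TensorFlow or data files. The names `load`, `paths`, `chunk_len` and `StubModel` are hypothetical: `load` stands in for reading one file (for example `lambda p: pd.read_csv(p).values`), and `StubModel` stands in for a stateful Keras model by just counting `reset_states()` calls. Like the question's code, it yields raw row chunks; a real Keras generator would normally yield `(inputs, targets)` tuples.

```python
from typing import Callable, Iterator, List, Sequence


def gendata(model, paths: List[str], load: Callable[[str], Sequence],
            chunk_len: int) -> Iterator[Sequence]:
    """Yield consecutive chunks of `chunk_len` rows, file after file.

    `model` is any object with a reset_states() method. Assumes every
    file has at least `chunk_len` rows (otherwise this loops forever).
    """
    i = 0                     # index of the file currently being read
    data = load(paths[i])
    x = 0                     # row offset into the current file
    while True:
        if x + chunk_len <= len(data):
            yield data[x:x + chunk_len]
            x += chunk_len
        else:
            # Current file exhausted: open the next one (cycling back to the
            # first) and clear the RNN state before yielding from it.
            i = (i + 1) % len(paths)
            data = load(paths[i])
            x = 0
            model.reset_states()


class StubModel:
    """Hypothetical stand-in for a stateful Keras model; counts resets."""
    def __init__(self):
        self.resets = 0

    def reset_states(self):
        self.resets += 1


files = {"a.csv": list(range(6)), "b.csv": list(range(10, 14))}
model = StubModel()
gen = gendata(model, ["a.csv", "b.csv"], files.__getitem__, chunk_len=2)
batches = [next(gen) for _ in range(5)]
print(batches)       # [[0, 1], [2, 3], [4, 5], [10, 11], [12, 13]]
print(model.resets)  # 1

# With a real model this would look roughly like:
# model.fit_generator(gendata(model, paths, lambda p: pd.read_csv(p).values, 32), ...)
```

One thing worth checking in practice: fit_generator pulls batches ahead of training into a queue (see its `max_queue_size` parameter), so the reset may fire a few batches before the model has finished consuming the previous file. In TensorFlow 2.x, fit_generator is deprecated and `model.fit` accepts generators directly.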
Chris K answered Apr 17 '26 at 01:04



