I'm using fit_generator with a generator that reads data from a file; when it reaches the end of the file, it loads data from the next file. I'm also using a stateful RNN in Keras, so I need to reset the state manually, in this case every time the generator loads a new file. How can I achieve this?
The generator looks something like this:
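```python
import pandas as pd

def gendata():
    # l: number of rows per chunk, defined elsewhere
    x = 0  # row offset into the current file
    path = 'somepath'
    df = pd.read_csv(path)
    while True:
        if x + l <= len(df):
            yield df.iloc[x:x+l, :]
            x += l
        else:
            path = newpath(path)  # my own helper: returns the next file's path
            df = pd.read_csv(path)
            x = 0  # start again at the top of the new file
            model.reset_states()  # this line obviously doesn't work
```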
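Just pass the generator a reference to the model, so it can call reset_states() itself whenever it switches files:

```python
def gendata(model):
    # ... same body as above; in the branch that loads the next file:
    model.reset_states()

model.fit_generator(gendata(model), ...)
```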
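A minimal, self-contained sketch of the same idea, runnable without Keras or pandas. Assumptions (all hypothetical, not from the original): `StubModel` stands in for a stateful Keras model and only counts `reset_states()` calls, and `files` is a list of in-memory lists replacing the `pd.read_csv(path)` / `newpath(path)` pair. It shows that the state is reset exactly once per file switch.

```python
from typing import Iterator, List, Sequence


class StubModel:
    """Hypothetical stand-in for a stateful Keras model; only counts resets."""

    def __init__(self) -> None:
        self.resets = 0

    def reset_states(self) -> None:
        self.resets += 1


def gendata(model, files: Sequence[Sequence[float]], l: int) -> Iterator[List[float]]:
    """Yield chunks of `l` rows, cycling through `files` forever.

    Assumes every file has at least `l` rows; otherwise it would loop
    through files without ever yielding.
    """
    i = 0            # index of the current file
    data = files[i]  # stands in for df = pd.read_csv(path)
    x = 0            # row offset into the current file
    while True:
        if x + l <= len(data):
            yield list(data[x:x + l])
            x += l
        else:
            i = (i + 1) % len(files)  # stands in for path = newpath(path)
            data = files[i]
            x = 0
            model.reset_states()      # the model was passed in, so this works


model = StubModel()
gen = gendata(model, files=[[1, 2, 3, 4], [5, 6, 7, 8, 9]], l=2)
chunks = [next(gen) for _ in range(5)]
print(chunks)        # [[1, 2], [3, 4], [5, 6], [7, 8], [1, 2]]
print(model.resets)  # 2
```

One caveat worth checking in your Keras version: by default fit_generator pulls batches from the generator in a background thread and buffers them in a queue (max_queue_size, 10 by default), so the reset can fire before the model has trained on the last queued batches of the previous file. If exact timing matters, passing workers=0 should make Keras call the generator on the main thread, one batch at a time.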