So I'm trying to use Keras' fit_generator with a custom data generator to feed into an LSTM network.
To illustrate the problem, I have created a toy example trying to predict the next number in a simple ascending sequence, and I use the Keras TimeseriesGenerator to create a Sequence instance:
import numpy as np
from keras.preprocessing.sequence import TimeseriesGenerator

WINDOW_LENGTH = 4
data = np.arange(0, 100).reshape(-1, 1)
data_gen = TimeseriesGenerator(data, data, length=WINDOW_LENGTH,
                               sampling_rate=1, batch_size=1)
I use a simple LSTM network:
from keras.layers import Input, LSTM, Dense
from keras.models import Model

data_dim = 1
input1 = Input(shape=(WINDOW_LENGTH, data_dim))
lstm1 = LSTM(100)(input1)
hidden = Dense(20, activation='relu')(lstm1)
output = Dense(data_dim, activation='linear')(hidden)
model = Model(inputs=input1, outputs=output)
model.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])
and train it using the fit_generator function:
model.fit_generator(generator=data_gen,
                    steps_per_epoch=32,
                    epochs=10)
And this trains perfectly, and the model makes predictions as expected.
Now the problem is that in my non-toy situation I want to process the data coming out of the TimeseriesGenerator before feeding it into fit_generator. As a step towards this, I created a generator function which just wraps the TimeseriesGenerator used previously:
def get_generator(data, targets, window_length=5, batch_size=32):
    while True:
        data_gen = TimeseriesGenerator(data, targets, length=window_length,
                                       sampling_rate=1, batch_size=batch_size)
        for i in range(len(data_gen)):
            x, y = data_gen[i]
            yield x, y
data_gen_custom = get_generator(data, data,
                                window_length=WINDOW_LENGTH, batch_size=1)
But now the strange thing is that when I train the model as before, but using this generator as the input,
model.fit_generator(generator=data_gen_custom,
                    steps_per_epoch=32,
                    epochs=10)
there is no error, but the training loss jumps up and down instead of decreasing steadily as it did with the first approach, and the model doesn't learn to make good predictions.
Any ideas what I'm doing wrong with my custom generator approach?
It could be because the object type changes from a Sequence (which is what a TimeseriesGenerator is) to a generic generator. The fit_generator function treats these differently: it can index into a Sequence and shuffle its batches, but it can only consume a plain generator sequentially. A cleaner solution would be to inherit from TimeseriesGenerator and override just the processing bit:
class CustomGen(TimeseriesGenerator):
    def __getitem__(self, idx):
        x, y = super()[idx]
        # do processing here
        return x, y
And use this class as before, since the rest of the internal logic remains the same.
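The Sequence-versus-generator distinction can be sketched in plain Python. This is only an illustration of the access pattern, not Keras internals: a Sequence-like object exposes __len__ and __getitem__, so its batches can be fetched in any order (which is what makes per-epoch shuffling possible), while a plain generator can only be consumed front to back.

```python
class BatchSequence:
    """Sequence-style access: has a length and supports indexing."""
    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __len__(self):
        # number of batches, rounding up
        return (len(self.data) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        start = idx * self.batch_size
        return self.data[start:start + self.batch_size]


def batch_generator(data, batch_size):
    """Generator-style access: sequential only, no len(), no indexing."""
    while True:
        for start in range(0, len(data), batch_size):
            yield data[start:start + batch_size]


seq = BatchSequence(list(range(10)), batch_size=4)
assert len(seq) == 3          # batch count is known up front
assert seq[2] == [8, 9]       # random access works, so batches can be shuffled

gen = batch_generator(list(range(10)), batch_size=4)
assert next(gen) == [0, 1, 2, 3]   # only sequential consumption is possible
```

Keras can exploit the extra structure of the first form; with the second, it has no choice but to take batches in the order they come.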
I personally had a problem with the code by nuric: for some reason I got an error saying that super is not subscriptable (super() returns a proxy object, which cannot be indexed directly). Here is my possible fix; let me know if it works for you:
class CustomGen(TimeseriesGenerator):
    def __getitem__(self, idx):
        x, y = super().__getitem__(idx)
        # do processing here
        return x, y