
StopIteration: generator_output = next(output_generator)

I have the following code, which I rewrote to work on a large-scale dataset. I am using a Python generator to fit the model on data yielded batch by batch.

def subtract_mean_gen(x_source,y_source,avg_image,batch):
    batch_list_x=[]
    batch_list_y=[]
    for line,y in zip(x_source,y_source):
        x=line.astype('float32')
        x=x-avg_image
        batch_list_x.append(x)
        batch_list_y.append(y)
        if len(batch_list_x) == batch:
            yield (np.array(batch_list_x),np.array(batch_list_y))
            batch_list_x=[]
            batch_list_y=[] 

model = resnet.ResnetBuilder.build_resnet_18((img_channels, img_rows, img_cols), nb_classes)
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

val = subtract_mean_gen(X_test,Y_test,avg_image_test,batch_size)
model.fit_generator(subtract_mean_gen(X_train,Y_train,avg_image_train,batch_size), steps_per_epoch=X_train.shape[0]//batch_size,epochs=nb_epoch,validation_data = val,
                    validation_steps = X_test.shape[0]//batch_size)

I obtain the following error:

239/249 [===========================>..] - ETA: 60s - loss: 1.3318 - acc: 0.8330Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/lib/python2.7/threading.py", line 801, in __bootstrap_inner
    self.run()
  File "/usr/lib/python2.7/threading.py", line 754, in run
    self.__target(*self.__args, **self.__kwargs)
  File "/usr/local/lib/python2.7/dist-packages/keras/utils/data_utils.py", line 560, in data_generator_task
    generator_output = next(self._generator)
StopIteration

240/249 [===========================>..] - ETA: 54s - loss: 1.3283 - acc: 0.8337Traceback (most recent call last):
  File "cifa10-copy.py", line 125, in <module>
    validation_steps = X_test.shape[0]//batch_size)
  File "/usr/local/lib/python2.7/dist-packages/keras/legacy/interfaces.py", line 87, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1809, in fit_generator
    generator_output = next(output_generator)
StopIteration

I looked into a similar question posted here; however, I was not able to work out why StopIteration is raised.

cswah asked Feb 09 '18 16:02


2 Answers

Generators for Keras must be infinite:

def subtract_mean_gen(x_source,y_source,avg_image,batch):
    while True:
        batch_list_x=[]
        batch_list_y=[]
        for line,y in zip(x_source,y_source):
            x=line.astype('float32')
            x=x-avg_image
            batch_list_x.append(x)
            batch_list_y.append(y)
            if len(batch_list_x) == batch:
                yield (np.array(batch_list_x),np.array(batch_list_y))
                batch_list_x=[]
                batch_list_y=[] 

The error occurs because Keras tries to fetch a new batch after your generator has already reached its end. Even though you specified the correct number of steps, Keras fills a queue of batches ahead of the training loop, so it keeps requesting batches from the generator while the last steps are still running.

You are apparently using the default queue size, which is 10 (the max_queue_size argument of fit_generator). That is why the exception appears 10 batches before the end: the queue is already trying to fetch a batch past the end of the data.
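
The effect of prefetching can be reproduced without Keras. The following is a minimal sketch, not Keras code: `prefetch` is a hypothetical stand-in for a consumer that requests more batches than the number of steps, and `finite_batches` / `infinite_batches` are illustrative names that mirror the two generator shapes discussed here, using plain lists instead of NumPy arrays.

```python
def finite_batches(samples, batch):
    """Yield full batches for a single pass, then stop (as in the question)."""
    buf = []
    for s in samples:
        buf.append(s)
        if len(buf) == batch:
            yield buf
            buf = []

def infinite_batches(samples, batch):
    """Same batching, but restart from the beginning forever."""
    while True:
        for full_batch in finite_batches(samples, batch):
            yield full_batch

def prefetch(gen, steps, queue_size):
    """Hypothetical prefetching consumer (an assumption, not Keras internals):
    it requests steps + queue_size batches but returns only the first steps."""
    fetched = []
    for _ in range(steps + queue_size):
        fetched.append(next(gen))  # raises StopIteration if gen is exhausted
    return fetched[:steps]

samples = list(range(10))
steps = len(samples) // 2  # 5 full batches per pass

try:
    prefetch(finite_batches(samples, 2), steps, queue_size=3)
except StopIteration:
    print("finite generator ran out while prefetching")

print(len(prefetch(infinite_batches(samples, 2), steps, queue_size=3)))  # 5
```

The finite generator supplies exactly `steps` batches, so any request beyond that raises StopIteration, while the infinite one simply starts another pass over the data.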

Daniel Möller answered Oct 16 '22 05:10


As the linked question you provided indicates, Keras generators have to iterate indefinitely, so that they can keep supplying elements to training for as long as needed. There is more information on this in this GitHub issue.

To do that, you must modify your generator like this:

def subtract_mean_gen(x_source,y_source,avg_image,batch):
    batch_list_x=[]
    batch_list_y=[]
    while True: # run forever, so you can generate elements indefinitely
        for line,y in zip(x_source,y_source):
            x=line.astype('float32')
            x=x-avg_image
            batch_list_x.append(x)
            batch_list_y.append(y)
            if len(batch_list_x) == batch:
                yield (np.array(batch_list_x),np.array(batch_list_y))
                batch_list_x=[]
                batch_list_y=[]
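
One detail worth noting: the two infinite generators on this page differ in where the batch lists are created, which matters when the number of samples is not a multiple of the batch size. The sketch below illustrates this with plain lists; the function names are made up for the illustration and the mean subtraction is left out.

```python
from itertools import islice

def batches_reset_each_pass(samples, batch):
    """Buffer is created inside the while loop: a leftover partial batch is dropped."""
    while True:
        buf = []
        for s in samples:
            buf.append(s)
            if len(buf) == batch:
                yield buf
                buf = []

def batches_carry_over(samples, batch):
    """Buffer is created once, outside the loop: leftovers carry into the next pass."""
    buf = []
    while True:
        for s in samples:
            buf.append(s)
            if len(buf) == batch:
                yield buf
                buf = []

samples = [0, 1, 2, 3, 4]  # 5 samples, batch size 2: one sample is left over per pass

print(list(islice(batches_reset_each_pass(samples, 2), 4)))
# [[0, 1], [2, 3], [0, 1], [2, 3]]
print(list(islice(batches_carry_over(samples, 2), 4)))
# [[0, 1], [2, 3], [4, 0], [1, 2]]
```

Both variants avoid StopIteration. The first never trains on the leftover sample, while the second uses every sample but lets batches straddle the boundary between passes.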

DarkCygnus answered Oct 16 '22 06:10