On fit_generator() / fit() and thread-safety

Context

In order to use fit_generator() in Keras, I use a generator function like this pseudocode one:

import numpy as np

def generator(data: np.ndarray):
    """Simple generator yielding some samples and targets.

    number_of_batches and length_sequence are assumed to be defined
    in the enclosing scope (this is pseudocode).
    """
    while True:
        for batch in range(number_of_batches):
            yield data[batch * length_sequence], data[(batch + 1) * length_sequence]

In Keras' fit_generator() I want to use workers=4 and use_multiprocessing=True; hence, I need a thread-safe generator.
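For illustration, such a call could look like the following sketch (model is a hypothetical compiled Keras model; data and number_of_batches are the names from the pseudocode above; steps_per_epoch is needed because the generator loops forever):

# Hypothetical training call using the generator defined above
model.fit_generator(generator(data),
                    steps_per_epoch=number_of_batches,
                    epochs=10,
                    workers=4,
                    use_multiprocessing=True)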

In answers on stackoverflow like here or here, and in the Keras docs, I read about creating a class inheriting from keras.utils.Sequence like this:

import numpy as np
from keras.utils import Sequence

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        return ...
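A Sequence is then passed to fit_generator() in place of a generator; a minimal usage sketch (x_train, y_train and model are hypothetical names assumed to exist):

train_seq = generatorClass(x_train, y_train, batch_size=32)
model.fit_generator(train_seq, epochs=10, workers=4, use_multiprocessing=True)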

When using a Sequence, Keras does not throw any warning with multiple workers and multiprocessing; a Sequence is supposed to be thread-safe.

Anyhow, since I am using my custom function, I stumbled upon Omer Zohar's code provided on GitHub, which makes my generator() thread-safe by adding a decorator. The code looks like this:

import threading

class threadsafe_iter:
    """
    Takes an iterator/generator and makes it thread-safe by
    serializing calls to the `next` method of the given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return self.it.__next__()


def threadsafe_generator(f):
    """A decorator that takes a generator function and makes it thread-safe."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))

    return g

Now I can do:

@threadsafe_generator
def generator(data):
    ...

The thing is: with this version of a thread-safe generator, Keras still emits a warning when using workers > 1 and use_multiprocessing=True, saying that the generator has to be thread-safe and that this can be avoided by using a Sequence.


My questions now are:

  1. Does Keras emit this warning only because the generator does not inherit from Sequence, or does Keras also check whether a generator is thread-safe in general?
  2. Is the approach I chose as thread-safe as the generatorClass(Sequence) version from the Keras docs?
  3. Are there any other approaches, different from these two examples, that lead to a thread-safe generator Keras can deal with?


Edit: In newer TensorFlow/Keras versions (TF 2 and later) fit_generator() is deprecated; instead, it is recommended to use fit() with the generator. However, the question applies to fit() with a generator as well.
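For TF 2 that would look like this sketch (same hypothetical names as above; fit() accepts the workers and use_multiprocessing arguments for generator input):

model.fit(generator(data),
          steps_per_epoch=number_of_batches,
          epochs=10,
          workers=4,
          use_multiprocessing=True)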

Asked by Markus on Jun 04 '19.

1 Answer

During my research on this, I came across some information answering my questions.

Note: As updated in the question, in newer TensorFlow/Keras versions (TF 2 and later) fit_generator() is deprecated; instead, it is recommended to use fit() with the generator. However, the answer still applies to fit() with a generator as well.


1. Does Keras emit this warning only because the generator does not inherit from Sequence, or does Keras also check whether a generator is thread-safe in general?

Taken from Keras' git repo (training_generators.py), I found the following in lines 46-52:

use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
    warnings.warn(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence'
                    ' class.'))

The definition of is_sequence(), taken from training_utils.py, lines 624-635, is:

def is_sequence(seq):
    """Determine if an object follows the Sequence API.
    # Arguments
        seq: a possible Sequence object
    # Returns
        boolean, whether the object follows the Sequence API.
    """
    # TODO Dref360: Decide which pattern to follow. First needs a new TF Version.
    return (getattr(seq, 'use_sequence_api', False)
            or set(dir(Sequence())).issubset(set(dir(seq) + ['use_sequence_api'])))

Judging from this piece of code, Keras only checks whether the passed generator uses Keras' Sequence API; it does not check whether a generator is thread-safe in general.
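As a side note, the duck-typed check above implies that a class can pass is_sequence() without actually inheriting from Sequence, e.g. by exposing a use_sequence_api attribute. A sketch based purely on the check above (not an officially documented API):

class SequenceLike:
    # is_sequence() returns True because of this attribute alone
    use_sequence_api = True

    def __len__(self):
        ...

    def __getitem__(self, idx):
        ...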


2. Is the approach I chose as thread-safe as the generatorClass(Sequence) version from the Keras docs?

As Omer Zohar has shown on GitHub, his decorator is thread-safe - I don't see any reason why it shouldn't be just as thread-safe for Keras (even though Keras will warn, as shown in 1.). The threading.Lock() used in the implementation can be considered thread-safe according to the docs:

A factory function that returns a new primitive lock object. Once a thread has acquired it, subsequent attempts to acquire it block, until it is released; any thread may release it.
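As a quick sanity check (my own sketch, assuming the threadsafe_iter class from above), one can pull from the wrapped iterator in several threads and verify that no item is handed out twice:

import threading

safe = threadsafe_iter(iter(range(1000)))
results = []

def consume():
    # next() is serialized by the lock inside threadsafe_iter
    for _ in range(100):
        results.append(next(safe))

threads = [threading.Thread(target=consume) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert len(results) == len(set(results)) == 400  # no duplicates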

The generator is also picklable, which can be tested like this (see this SO Q&A for further information):

import pickle

# Dump yielded data in order to check that it is picklable;
# generator() loops forever, so stop after a few samples
with open("test.pickle", "wb") as outfile:
    for i, yielded_data in enumerate(generator(data)):
        pickle.dump(yielded_data, outfile, protocol=pickle.HIGHEST_PROTOCOL)
        if i >= 10:
            break

Summing up, I would even suggest implementing threading.Lock() when you extend Keras' Sequence(), like so:

import threading

import numpy as np
from keras.utils import Sequence

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.lock = threading.Lock()   # Set self.lock

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        with self.lock:                # Use self.lock
            batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

            return ...

Edit 24/04/2020:

By using self.lock = threading.Lock() you might run into the following error:

TypeError: can't pickle _thread.lock objects

In case this happens, try replacing with self.lock: inside __getitem__ by with threading.Lock(): and comment out / delete the self.lock = threading.Lock() inside __init__. Note, however, that a lock created freshly on every call does not actually serialize concurrent access; it merely avoids the pickling error.

It seems there are some problems when storing a lock object inside a class (see for example this Q&A).
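An alternative workaround (my own sketch, not part of the original sources) is to keep the per-instance lock but exclude it from pickling by adding __getstate__/__setstate__ to the generatorClass above, recreating the lock on unpickling:

    def __getstate__(self):
        # Remove the unpicklable lock before pickling
        state = self.__dict__.copy()
        del state["lock"]
        return state

    def __setstate__(self, state):
        # Restore attributes and create a fresh lock
        self.__dict__.update(state)
        self.lock = threading.Lock()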


3. Are there any other approaches, different from these two examples, that lead to a thread-safe generator Keras can deal with?

During my research I did not encounter any other method. Of course I cannot say this with 100% certainty.

Answered by Markus on Oct 25 '22.