Keras fit_generator() has a parameter pickle_safe, which defaults to False.
Can training run faster if my generator is pickle-safe and I set the flag to True accordingly?
According to the Keras docs:
pickle_safe: If True, use process based threading. Note that because this implementation relies on multiprocessing, you should not pass non picklable arguments to the generator as they can't be passed easily to children processes.
I don't understand exactly what this is saying.
How can I determine whether my arguments are pickle-safe or not?
If it's relevant:
- I'm passing in a custom generator
- the generator function takes the arguments X_train, y_train, batch_size, p_keep (they are of type np.array, int, float)
- I'm not using a GPU
- Also, I'm using Keras 1.2.1, though I believe this argument behaves the same in Keras 2
I have no familiarity with Keras, but from a glance at the documentation, pickle_safe just means that the tuples produced by your generator must be "picklable".
pickle is a standard Python module that is used to serialize and deserialize objects. The standard multiprocessing implementation uses the pickle mechanism to share objects between different processes: since the two processes don't share the same address space, they cannot directly see the same Python objects. So, to send an object from process A to process B, it is pickled in A (which produces a sequence of bytes in a specific, well-known format), the pickled bytes are sent to B via an interprocess-communication mechanism, and they are unpickled in B, producing a copy of A's original object in B's address space.
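A minimal sketch of that round trip, using multiprocessing.Pipe. For brevity both ends of the pipe live in a single process here, which is an assumption made only to keep the example short; the pickling and unpickling steps are the same ones that occur when the receiving end belongs to a child process. The dictionary contents are made-up values.

```python
import multiprocessing

# Hypothetical payload, standing in for whatever needs to cross processes.
original = {"batch_size": 32, "p_keep": 0.5}

# Pipe() returns two connected endpoints.
sender, receiver = multiprocessing.Pipe()

sender.send(original)        # pickles `original` and writes the bytes
received = receiver.recv()   # reads the bytes and unpickles them

print(received == original)  # same contents
print(received is original)  # but a separate copy, not the same object

sender.close()
receiver.close()
```

The first print shows True and the second False: what arrives on the other side is an equal copy, so anything that cannot survive pickling cannot be sent this way.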
So, to discover whether your objects are picklable, just invoke, say, pickle.dumps on them.
>>> import pickle
>>> class MyObject:
...     def __init__(self, a, b, c):
...         self.a = a
...         self.b = b
...         self.c = c
...
>>> foo = MyObject(1, 2, 3)
>>> pickle.dumps(foo)
b'\x80\x03c__main__\nMyObject\nq\x00)\x81q\x01}q\x02(X\x01\x00\x00\x00cq\x03K\x03X\x01\x00\x00\x00aq\x04K\x01X\x01\x00\x00\x00bq\x05K\x02ub.'
>>>
dumps produces a byte string. We can now reconstitute the foo object from the byte string as bar using loads:
>>> foo_pick = pickle.dumps(foo)
>>> bar = pickle.loads(foo_pick)
>>> bar
<__main__.MyObject object at 0x7f5e262ece48>
>>> bar.a, bar.b, bar.c
(1, 2, 3)
If something is not picklable, you'll get an exception. For example, lambdas can't be pickled:
>>> class MyOther:
...     def __init__(self, a, b, c):
...         self.a = a
...         self.b = b
...         self.c = c
...         self.printer = lambda: print(self.a, self.b, self.c)
...
>>> other = MyOther(1, 2, 3)
>>> other_pick = pickle.dumps(other)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: Can't pickle local object 'MyOther.__init__.<locals>.<lambda>'
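The same pickle.dumps check can be wrapped in a small helper and applied to each generator argument by name. This is only a sketch: find_unpicklable is a hypothetical helper, not part of Keras or pickle, and the values are stand-ins (a nested list plays the role of the NumPy arrays, which themselves support pickling; the transform argument is invented to show a failure).

```python
import pickle


def find_unpicklable(**named_args):
    """Return the names of the arguments that pickle.dumps rejects."""
    failures = []
    for name, value in named_args.items():
        try:
            pickle.dumps(value)
        except (pickle.PicklingError, AttributeError, TypeError):
            # These are the exception types pickle typically raises
            # for objects it cannot serialize.
            failures.append(name)
    return failures


# Stand-in values mirroring the arguments named in the question.
ok_args = dict(X_train=[[0.0, 1.0], [2.0, 3.0]], y_train=[0, 1],
               batch_size=32, p_keep=0.5)

# Hypothetical extra argument: a lambda, which pickle cannot serialize.
bad_args = dict(batch_size=32, transform=lambda x: x * 2)

print(find_unpicklable(**ok_args))
print(find_unpicklable(**bad_args))
```

An empty list means every argument survived pickle.dumps; any names that come back point at the arguments to replace, for example by swapping a lambda for a module-level function.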
See the documentation for more info: https://docs.python.org/3.5/library/pickle.html?highlight=pickle#what-can-be-pickled-and-unpickled