The `warnings.catch_warnings()` context manager is not thread safe. How do I use it in a parallel processing environment?
The code below solves a maximization problem using parallel processing with Python's `multiprocessing` module. It takes a list of (immutable) widgets, partitions them up (see Efficient multiprocessing of massive, brute force maximization in Python 3), finds the maxima ("finalists") of all the partitions, and then finds the maximum ("champion") of those finalists. If I understand my own code correctly (and I wouldn't be here if I did), I'm sharing memory with all the child processes to give them the input widgets, and `multiprocessing` uses an operating-system-level pipe and pickling to send the finalist widgets back to the main process when the workers are done.
I want to catch the redundant widget warnings caused by the widgets' re-instantiation after the unpickling that happens when they come out of the inter-process pipe. When widget objects instantiate, they validate their own data, emitting warnings from the Python standard `warnings` module to tell the app's user that the widget suspects a problem with the user's input data. Because unpickling causes objects to instantiate, my understanding of the code implies that each widget object is re-instantiated exactly once, and only if it is a finalist, after it comes out of the pipe -- see the next section for why this isn't correct.
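For concreteness, here is a minimal sketch of the kind of validation I mean; this is not my real class, and `frobnal_count` is a made-up field:

```python
import warnings

class Widget:
    """Hypothetical stand-in for the real class: it validates its data when it
    is instantiated and warns the user about anything suspicious."""
    def __init__(self, frobnal_count):
        if frobnal_count < 0:
            # Non-fatal data problem: warn the user, but keep the object usable.
            warnings.warn('widget has a negative frobnal count')
        self.frobnal_count = frobnal_count
```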
The widgets were already created before being frobnicated, so the user is already painfully aware of what input he got wrong and doesn't want to hear about it again. These are the warnings I'd like to catch with the `warnings` module's `catch_warnings()` context manager (i.e., a `with` statement).
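In a single process this would be straightforward -- something like the following, assuming the validation warnings are plain `UserWarning` instances and using the sketch class above:

```python
import warnings

with warnings.catch_warnings():
    # Temporarily ignore the (assumed) UserWarning category inside this block.
    warnings.simplefilter('ignore', UserWarning)
    widget = Widget(-1)  # would normally emit the validation warning
```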
In my tests I've narrowed down when the superfluous warnings are emitted to anywhere between what I've labeled below as Line A and Line B. What surprises me is that the warnings are being emitted in places other than just near `output_queue.get()`. This implies to me that `multiprocessing` sends the widgets to the workers using pickling.

The upshot is that putting a context manager created by `warnings.catch_warnings()` even around everything from Line A to Line B, and setting the right warnings filter inside this context, does not catch the warnings. This implies to me that the warnings are being emitted in the worker processes. Putting this context manager around the worker code does not catch the warnings either.
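One way I could double-check which process actually emits them (a diagnostic sketch only, not part of the real code) is to tag every warning with the emitting process's PID via the documented `warnings.showwarning` hook; with a fork start method the hook is inherited by the workers, while with spawn it would have to be installed inside the worker as well:

```python
import os
import warnings

def showwarning_with_pid(message, category, filename, lineno, file=None, line=None):
    # Print the emitting process's PID so parent- and worker-emitted warnings
    # can be told apart.
    print('[pid {}] {}: {} ({}:{})'.format(
        os.getpid(), category.__name__, message, filename, lineno))

warnings.showwarning = showwarning_with_pid
```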
This example omits the code for deciding whether the problem size is too small to bother with forking processes, for importing `multiprocessing`, and for defining `my_frobnal_counter` and `my_load_balancer`.
"Call `frobnicate(list_of_widgets)` to get the widget with the most frobnals" def frobnicate_parallel_worker(widgets, output_queue): resultant_widget = max(widgets, key=my_frobnal_counter) output_queue.put(resultant_widget) def frobnicate_parallel(widgets): output_queue = multiprocessing.Queue() # partitions: Generator yielding tuples of sets partitions = my_load_balancer(widgets) processes = [] # Line A: Possible start of where the warnings are coming from. for partition in partitions: p = multiprocessing.Process( target=frobnicate_parallel_worker, args=(partition, output_queue)) processes.append(p) p.start() finalists = [] for p in processes: finalists.append(output_queue.get()) # Avoid deadlocks in Unix by draining queue before joining processes for p in processes: p.join() # Line B: Warnings no longer possible after here. return max(finalists, key=my_frobnal_counter)
You can try to override the `Process.run` method to use `warnings.catch_warnings`:
```python
>>> from multiprocessing import Process
>>>
>>> def yell(text):
...     import warnings
...     print 'about to yell %s' % text
...     warnings.warn(text)
...
>>> class CustomProcess(Process):
...     def run(self, *args, **kwargs):
...         import warnings
...         with warnings.catch_warnings():
...             warnings.simplefilter("ignore")
...             return Process.run(self, *args, **kwargs)
...
>>> if __name__ == '__main__':
...     quiet = CustomProcess(target=yell, args=('...not!',))
...     quiet.start()
...     quiet.join()
...     noisy = Process(target=yell, args=('AAAAAAaaa!',))
...     noisy.start()
...     noisy.join()
...
about to yell ...not!
about to yell AAAAAAaaa!
__main__:4: UserWarning: AAAAAAaaa!
>>>
```
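Applied to your code, that would mean building the workers from the subclass instead of from `multiprocessing.Process` directly -- roughly like this (a sketch only, assuming `multiprocessing` is imported as in your example; I haven't run it against your real widgets):

```python
class QuietProcess(multiprocessing.Process):
    def run(self, *args, **kwargs):
        import warnings
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # or a narrower filter
            return multiprocessing.Process.run(self, *args, **kwargs)

# ...and inside frobnicate_parallel:
p = QuietProcess(target=frobnicate_parallel_worker,
                 args=(partition, output_queue))
```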
Or you can use some of the internals (`__warningregistry__`):
```python
>>> from multiprocessing import Process
>>> import exceptions
>>> def yell(text):
...     import warnings
...     print 'about to yell %s' % text
...     warnings.warn(text)
...     # not filtered
...     warnings.warn('complimentary second warning.')
...
>>> WARNING_TEXT = 'AAAAaaaaa!'
>>> WARNING_TYPE = exceptions.UserWarning
>>> WARNING_LINE = 4
>>>
>>> class SelectiveProcess(Process):
...     def run(self, *args, **kwargs):
...         registry = globals().setdefault('__warningregistry__', {})
...         registry[(WARNING_TEXT, WARNING_TYPE, WARNING_LINE)] = True
...         return Process.run(self, *args, **kwargs)
...
>>> if __name__ == '__main__':
...     p = SelectiveProcess(target=yell, args=(WARNING_TEXT,))
...     p.start()
...     p.join()
...
about to yell AAAAaaaaa!
__main__:6: UserWarning: complimentary second warning.
>>>
```
The unpickling would not cause the `__init__` to be executed twice. I ran the following code on Windows, and it doesn't happen (each `__init__` is run precisely once).
Therefore, you need to provide us with the code from `my_load_balancer` and from the widgets' class. At this point, your question simply doesn't provide enough information.

As a random guess, you might check whether `my_load_balancer` makes copies of widgets, causing them to be instantiated once again (a quick identity check is sketched after the listing below).
```python
import multiprocessing
import collections

"Call `frobnicate(list_of_widgets)` to get the widget with the most frobnals"

def my_load_balancer(widgets):
    partitions = tuple(set() for _ in range(8))
    for i, widget in enumerate(widgets):
        partitions[i % 8].add(widget)
    for partition in partitions:
        yield partition

def my_frobnal_counter(widget):
    return widget.id

def frobnicate_parallel_worker(widgets, output_queue):
    resultant_widget = max(widgets, key=my_frobnal_counter)
    output_queue.put(resultant_widget)

def frobnicate_parallel(widgets):
    output_queue = multiprocessing.Queue()
    # partitions: Generator yielding tuples of sets
    partitions = my_load_balancer(widgets)
    processes = []
    # Line A: Possible start of where the warnings are coming from.
    for partition in partitions:
        p = multiprocessing.Process(
            target=frobnicate_parallel_worker,
            args=(partition, output_queue))
        processes.append(p)
        p.start()
    finalists = []
    for p in processes:
        finalists.append(output_queue.get())
    # Avoid deadlocks in Unix by draining queue before joining processes
    for p in processes:
        p.join()
    # Line B: Warnings no longer possible after here.
    return max(finalists, key=my_frobnal_counter)

class Widget:
    id = 0

    def __init__(self):
        print('initializing Widget {}'.format(self.id))
        self.id = Widget.id
        Widget.id += 1

    def __str__(self):
        return str(self.id)

    def __repr__(self):
        return str(self)

def main():
    widgets = [Widget() for _ in range(16)]
    result = frobnicate_parallel(widgets)
    print(result.id)

if __name__ == '__main__':
    main()
```
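To test that guess, something like the following identity check (purely a hypothetical diagnostic, reusing the names from the listing above) could be run before any processes are started:

```python
widgets = [Widget() for _ in range(16)]
for partition in my_load_balancer(widgets):
    for w in partition:
        # If my_load_balancer copied anything, some partition member would not
        # be the identical object from the original list.
        assert any(w is original for original in widgets), 'widget was copied'
```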