
Make Python unittest fail on exception from any thread

I am using the unittest framework to automate integration tests of multi-threaded Python code, external hardware and embedded C. Despite my blatant abuse of a unit-testing framework for integration testing, it works really well. Except for one problem: I need the test to fail if an exception is raised from any of the spawned threads. Is this possible with the unittest framework?

A simple but non-workable solution would be to either a) refactor the code to avoid multi-threading or b) test each thread separately. I cannot do that because the code interacts asynchronously with the external hardware. I have also considered implementing some kind of message passing to forward the exceptions to the main unittest thread. This would require significant testing-related changes to the code being tested, and I want to avoid that.

Time for an example. Can I modify the test script below to fail on the exception raised in my_thread without modifying the x.ExceptionRaiser class?

import unittest
import x

class Test(unittest.TestCase):
    def test_x(self):
        my_thread = x.ExceptionRaiser()
        # Test case should fail when thread is started and raises
        # an exception.
        my_thread.start()
        my_thread.join()

if __name__ == '__main__':
    unittest.main()
Jakob Buron asked Sep 18 '12 20:09


2 Answers

At first, sys.excepthook looked like a solution. It is a global hook which is called every time an exception goes uncaught.

Unfortunately, this does not work. Why? Well, threading wraps your run function in code that catches the exception and prints the lovely tracebacks you see on screen (notice how it always tells you Exception in thread {Name of your thread here}? This is how it's done). Because the exception is handled there, it never reaches sys.excepthook.
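To see this for yourself, here is a small sketch (the names `recording_hook` and `boom` are made up for illustration). It installs a recording `sys.excepthook`, raises in a worker thread, and checks that the hook was never called. The traceback still appears on stderr, but it is printed by `threading` itself.

```python
import sys
import threading

calls = []


def recording_hook(exc_type, exc_value, exc_tb):
    # Would record any exception that actually reaches sys.excepthook.
    calls.append(exc_type)


def boom():
    raise ValueError('raised inside a worker thread')


original_sys_hook = sys.excepthook
sys.excepthook = recording_hook
try:
    t = threading.Thread(target=boom)
    t.start()
    t.join()  # threading prints "Exception in thread ..." to stderr on its own
finally:
    sys.excepthook = original_sys_hook

# The worker's exception never reached sys.excepthook.
print(calls)  # []
```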

Starting with Python 3.8, there is a function which you can override to make this work: threading.excepthook.

... threading.excepthook() can be overridden to control how uncaught exceptions raised by Thread.run() are handled

So what do we do? Replace this function with our logic, and voilà:

For Python >= 3.8:

import os
import threading
import traceback


class GlobalExceptionWatcher(object):
    def _store_excepthook(self, args):
        '''
        Used as an exception handler which stores any uncaught exceptions.
        '''
        # Still let the original hook print the traceback.
        self.__org_hook(args)
        formatted_exc = traceback.format_exception(args.exc_type, args.exc_value, args.exc_traceback)
        # format_exception() returns lines that already end in '\n'.
        self._exceptions.append(''.join(formatted_exc))

    def __enter__(self):
        '''
        Register us to the hook.
        '''
        self._exceptions = []
        self.__org_hook = threading.excepthook
        threading.excepthook = self._store_excepthook
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        '''
        Remove us from the hook, and make sure no exceptions were thrown.
        '''
        threading.excepthook = self.__org_hook
        if len(self._exceptions) != 0:
            tracebacks = os.linesep.join(self._exceptions)
            raise Exception(f'Exceptions in other threads: {tracebacks}')

For older versions of Python, this is a bit more complicated. Long story short, it appears that the threading module has an undocumented import which does something along the lines of:

threading._format_exc = traceback.format_exc

Not very surprisingly, this function is only called when an exception is thrown from a thread's run function.

So, for Python <= 3.7:

import os
import threading


class GlobalExceptionWatcher(object):
    def _store_excepthook(self):
        '''
        Used as an exception handler which stores any uncaught exceptions.
        '''
        # The original hook is traceback.format_exc, which returns a string
        # that threading then prints, so it must be returned.
        formatted_exc = self.__org_hook()
        self._exceptions.append(formatted_exc)
        return formatted_exc

    def __enter__(self):
        '''
        Register us to the hook.
        '''
        self._exceptions = []
        self.__org_hook = threading._format_exc
        threading._format_exc = self._store_excepthook
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        '''
        Remove us from the hook, and make sure no exceptions were thrown.
        '''
        threading._format_exc = self.__org_hook
        if len(self._exceptions) != 0:
            tracebacks = os.linesep.join(self._exceptions)
            raise Exception('Exceptions in other threads: %s' % tracebacks)

Usage:

my_thread = x.ExceptionRaiser()
# will fail when thread is started and raises an exception.
with GlobalExceptionWatcher():
    my_thread.start()
    my_thread.join()
            

You still need to join the threads yourself, so that every thread has finished (and any exception it raised has been recorded) before the with block exits. On exit, the context manager checks for any exceptions thrown in other threads and raises an exception accordingly.


THE CODE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED

The Python <= 3.7 version is an undocumented, sort-of-horrible hack. I tested it on Linux and Windows, and it seems to work. Use it at your own risk.

Ohad answered Sep 29 '22 02:09


I've come across this problem myself, and the only solution I've been able to come up with is subclassing Thread to include an attribute recording whether or not it terminated with an uncaught exception:

from threading import Thread


class ErrThread(Thread):
    """
    A subclass of Thread that stores the exception if the thread does
    not exit normally.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.err = None

    def run(self):
        try:
            super().run()
        except Exception as exc:
            # In Python 3, "except ... as" only accepts a plain name,
            # so copy the exception onto the instance explicitly.
            self.err = exc


class TaskQueue(object):
    """
    A utility class to run ErrThread objects in parallel and raise an exception
    in the event that *any* of them fail.
    """

    def __init__(self, *tasks):
        self.threads = []
        for t in tasks:
            if isinstance(t, dict):
                # A dict of Thread keyword arguments, e.g. {'target': f, 'args': (1,)}
                self.threads.append(ErrThread(**t))
            else:
                # A bare callable
                self.threads.append(ErrThread(target=t))

    def run(self):
        for t in self.threads:
            t.start()
        for t in self.threads:
            t.join()
            if t.err is not None:
                raise Exception('Thread %s failed with error: %s' % (t.name, t.err)) from t.err
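A related alternative, assuming you control how the threads are launched (the question notes this is not always possible): `concurrent.futures` already does this bookkeeping, because `Future.result()` re-raises the worker's exception in the calling thread. `run_all` is a hypothetical helper, a sketch of that swapped-in technique rather than a drop-in replacement for `TaskQueue`. Unlike `TaskQueue`, it re-raises the original exception type instead of wrapping it in a generic `Exception`.

```python
from concurrent.futures import ThreadPoolExecutor


def run_all(*callables):
    """Run each zero-argument callable in its own worker thread.

    Returns the results in submission order; if any callable raised,
    the first failure (in submission order) is re-raised here.
    """
    with ThreadPoolExecutor(max_workers=max(1, len(callables))) as pool:
        futures = [pool.submit(fn) for fn in callables]
        # Future.result() blocks until the worker finishes and re-raises
        # its exception; leaving the with block waits for the rest.
        return [f.result() for f in futures]
```

For example, `run_all(lambda: 1, lambda: 2)` returns `[1, 2]`, while `run_all(lambda: 1 / 0)` raises `ZeroDivisionError` in the caller, which unittest reports like any other exception in the test body.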
damzam answered Sep 29 '22 01:09